""" Contains models for sci-kit learn estimators """
try:
from sklearn.external import joblib as pickle
except ImportError:
pass
try:
import dill as pickle
except ImportError:
pass
from .base import BaseModel
class SklearnModel(BaseModel):
    """ Base class for scikit-learn models

    Attributes
    ----------
    estimator
        an instance of scikit-learn estimator

    Notes
    -----
    **Configuration**

    estimator - an instance of scikit-learn estimator

    load / path - a path to a pickled estimator

    Examples
    --------
    .. code-block:: python

        pipeline
            .init_model('static', SklearnModel, 'my_model',
                        config={'estimator': sklearn.linear_model.SGDClassifier(loss='huber')})

        pipeline
            .init_model('static', SklearnModel, 'my_model',
                        config={'load/path': '/path/to/estimator.pickle'})
    """
    def __init__(self, *args, **kwargs):
        # Set before calling the parent constructor so the attribute exists even if
        # BaseModel.__init__ builds/loads the model.
        # NOTE(review): presumably BaseModel.__init__ stores ``self.config`` and calls
        # build() or load() — confirm against .base.
        self.estimator = None
        super().__init__(*args, **kwargs)

    def build(self, *args, **kwargs):
        """ Define the model by taking the estimator instance from ``config['estimator']``

        Extra positional and keyword arguments are accepted for interface
        compatibility and ignored.
        """
        _ = args, kwargs
        self.estimator = self.config.get('estimator')

    def reset(self):
        """ Reset the trained model to allow a new training from scratch

        Re-runs :meth:`build`, which re-reads ``config['estimator']``.
        """
        self.build()

    def load(self, path):
        """ Load the model.

        The estimator is deserialized with the module-level ``pickle`` serializer
        (whichever of the optional backends this module managed to import).

        Parameters
        ----------
        path : str or path-like or file object
            a full path to a file from which a model will be loaded,
            or an already opened binary file object
        """
        if hasattr(path, 'read'):
            # Already a file object: hand it to the serializer unchanged.
            self.estimator = pickle.load(path)
        else:
            # dill's ``load`` (like stdlib pickle's) expects a file object, not a
            # filename, so open the file ourselves; joblib accepts either form.
            with open(path, 'rb') as file:
                self.estimator = pickle.load(file)

    def save(self, path):
        """ Save the model.

        Parameters
        ----------
        path : str or path-like or file object
            a full path to a file where a model will be saved to,
            or an already opened binary file object

        Raises
        ------
        ValueError
            if there is no estimator to save
        """
        if self.estimator is None:
            raise ValueError("Scikit-learn estimator does not exist. Check your config for 'estimator'.")
        if hasattr(path, 'write'):
            # Already a file object: hand it to the serializer unchanged.
            pickle.dump(self.estimator, path)
        else:
            # Open in binary mode so both dill and joblib receive a writable file.
            with open(path, 'wb') as file:
                pickle.dump(self.estimator, file)

    def train(self, x, y, *args, **kwargs):
        """ Train the model with the data provided

        Uses the estimator's ``partial_fit`` when it has one (incremental training
        across calls), otherwise ``fit``.

        Parameters
        ----------
        x : array-like
            Subset of the training data, shape (n_samples, n_features)
        y : numpy array
            Subset of the target values, shape (n_samples,)

        Notes
        -----
        Extra arguments are forwarded unchanged to ``partial_fit``/``fit``.
        For more details and other parameters look at the documentation for the estimator used.
        """
        if hasattr(self.estimator, 'partial_fit'):
            self.estimator.partial_fit(x, y, *args, **kwargs)
        else:
            self.estimator.fit(x, y, *args, **kwargs)

    def predict(self, x, *args, **kwargs):
        """ Predict with the data provided

        Parameters
        ----------
        x : array-like
            Data to make predictions for, shape (n_samples, n_features)

        Notes
        -----
        Extra arguments are forwarded unchanged to the estimator's ``predict``.
        For more details and other parameters look at the documentation for the estimator used.

        Returns
        -------
        array
            Predicted value per sample, shape (n_samples,)
        """
        return self.estimator.predict(x, *args, **kwargs)