Пример #1
0
    def fit(self, train_generator=None, validation_generator=None, **kwargs):
        '''
        Pass through method to external model after running through pipeline

        Optionally overwrite normal method to pass in the generator directly. Used
        to speedup training by caching the transformed input before training the
        model - avoids downloading, reading, encoding images in every batch

        :param train_generator: optional pre-built training input. When omitted
            it is produced by the pipeline from the train split
        :param validation_generator: optional pre-built validation input. When
            omitted together with train_generator it is produced by the pipeline
            from the validation split; a value passed by the caller is always
            kept as is
        :param kwargs: accepted for signature compatibility, not used here
        :return: self
        :raises ModelError: if no pipeline has been set
        '''
        if self.pipeline is None:
            raise ModelError('Must set pipeline before fitting')

        if self.state['fitted']:
            LOGGER.warning('Cannot refit model, skipping operation')
            return self

        if train_generator is None:
            # Explicitly fit only on train split
            train_generator = self.pipeline.transform(
                X=None,
                dataset_split=TRAIN_SPLIT,
                return_y=True,
                infinite_loop=True,
                **self.get_params())

            # Only fall back to the pipeline when the caller did not hand in
            # a validation generator - never discard one passed explicitly
            if validation_generator is None:
                validation_generator = self.pipeline.transform(
                    X=None,
                    dataset_split=VALIDATION_SPLIT,
                    return_y=True,
                    infinite_loop=True,
                    **self.get_params())

        self._fit(train_generator, validation_generator)

        # Mark the state so it doesnt get refit and can now be saved
        self.state['fitted'] = True

        return self
Пример #2
0
    def predict(self, X, **kwargs):
        '''
        Pass through method to external model after running through pipeline

        :param X: input handed to the pipeline's transform
        :param kwargs: forwarded to the pipeline's transform
        :return: the result of the external model's predict on the
            transformed input
        :raises ModelError: if the model has not been fitted or no pipeline
            is set
        '''
        if not self.state['fitted']:
            raise ModelError('Must fit model before predicting')

        # Raise a descriptive error instead of an AttributeError on None
        if self.pipeline is None:
            raise ModelError('Must set pipeline before predicting')

        transformed = self.pipeline.transform(X, **kwargs)

        return self.external_model.predict(transformed)
Пример #3
0
    def save(self, **kwargs):
        '''
        Save the model via the parent class after refreshing derived attributes

        Before delegating to the parent save:
        1) params are recomputed and stored on the instance
        2) feature metadata is recomputed and stored on the instance

        :param kwargs: forwarded to get_params, get_feature_metadata and the
            parent save
        :raises ModelError: if no pipeline is set or the model is not fitted
        '''
        pipeline_missing = self.pipeline is None
        if pipeline_missing:
            raise ModelError('Must set pipeline before saving')

        is_fitted = self.state['fitted']
        if not is_fitted:
            raise ModelError('Must fit model before saving')

        # Refresh both attributes ahead of the parent save
        current_params = self.get_params(**kwargs)
        self.params = current_params
        current_metadata = self.get_feature_metadata(**kwargs)
        self.feature_metadata = current_metadata

        super(BaseModel, self).save(**kwargs)

        # Sqlalchemy updates relationship references after save so reload class
        self.pipeline.load(load_externals=False)
Пример #4
0
    def fit(self, **kwargs):
        '''
        Fit the external model on the pipeline's transformed train split

        :param kwargs: forwarded to the external model's fit
        :return: self
        :raises ModelError: if no pipeline has been set
        '''
        if self.pipeline is None:
            raise ModelError('Must set pipeline before fitting')

        already_fitted = self.state['fitted']
        if already_fitted:
            LOGGER.warning('Cannot refit model, skipping operation')
            return self

        # Only the train split is used for fitting
        features, targets = self.pipeline.transform(
            X=None, dataset_split=TRAIN_SPLIT, return_y=True)

        # Squeeze y so a single column is passed with reduced dimensionality
        squeezed_targets = targets.squeeze()
        self.external_model.fit(features, squeezed_targets, **kwargs)

        self.state['fitted'] = True
        return self
Пример #5
0
 def assert_fitted(self, msg=""):
     """
     Guard helper: raise ModelError with the given message unless
     self.fitted is truthy
     """
     if self.fitted:
         return
     raise ModelError(msg)
Пример #6
0
 def assert_pipeline(self, msg=""):
     """
     Guard helper: raise ModelError with the given message unless a pipeline
     is set and its fitted attribute is truthy
     """
     pipeline = self.pipeline
     if pipeline is not None and pipeline.fitted:
         return
     raise ModelError(msg)
Пример #7
0
 def assert_fitted(self, msg=''):
     '''
     Guard helper: do nothing when self.fitted is truthy, otherwise raise
     ModelError carrying the supplied message
     '''
     is_fitted = self.fitted
     if is_fitted:
         return
     raise ModelError(msg)