예제 #1
0
 def _default_callbacks(self):
     """Return the default ``(name, callback)`` pairs for the net.

     The list holds, in this order: an ``EpochTimer``, a ``BatchScoring``
     for the training loss (``on_train=True``), a ``BatchScoring`` for the
     validation loss, and a ``PrintLog``.

     Both loss scorers use ``noop`` as their ``target_extractor``.
     Presumably the real targets are not needed because the loss was
     already recorded for the batch; confirm against ``train_loss_score`` /
     ``valid_loss_score``.
     """
     # Instantiate in the same order the list lists them.
     timer = EpochTimer()
     train_scorer = BatchScoring(
         train_loss_score,
         name='train_loss',
         on_train=True,
         target_extractor=noop,
     )
     valid_scorer = BatchScoring(
         valid_loss_score,
         name='valid_loss',
         target_extractor=noop,
     )
     printer = PrintLog()
     return [
         ('epoch_timer', timer),
         ('train_loss', train_scorer),
         ('valid_loss', valid_scorer),
         ('print_log', printer),
     ]
예제 #2
0
def callbacks_monkey_patch(self):
    """Return default callbacks whose loss scorers encode the batch targets.

    This is defined at module level with an explicit ``self``. Judging by the
    name, it is presumably assigned onto an estimator class in place of its
    ``_default_callbacks``; confirm at the call site. ``self`` must expose an
    ``encoder`` with a ``transform`` method.

    Returns:
        A list of ``(name, callback)`` tuples, in order: an ``EpochTimer``,
        a train-loss ``BatchScoring`` (``on_train=True``), a validation-loss
        ``BatchScoring`` and a ``PrintLog``.
    """
    def encode_targets(y):
        # Reshape the batch targets into a single column and run them
        # through the encoder. ``self.encoder`` is looked up on every call,
        # not when the callbacks are built.
        return self.encoder.transform(np.array(y).reshape(-1, 1))

    # NOTE(review): local functions (like lambdas) cannot be pickled by the
    # standard ``pickle`` module. If an estimator holding these callbacks is
    # pickled, that will fail; verify whether that happens.
    timer = EpochTimer()
    train_scorer = BatchScoring(
        train_loss_score,
        name='train_loss',
        on_train=True,
        target_extractor=encode_targets,
    )
    valid_scorer = BatchScoring(
        valid_loss_score,
        name='valid_loss',
        target_extractor=encode_targets,
    )
    printer = PrintLog()
    return [
        ('epoch_timer', timer),
        ('train_loss', train_scorer),
        ('valid_loss', valid_scorer),
        ('print_log', printer),
    ]
    def final_fit(self, dataset, best_params=None, train_all=False):
        """Refit ``self.best_model`` with the chosen parameters, then pickle ``self``.

        Args:
            dataset: Data passed to ``self.best_model.fit`` together with
                ``y=None``. Presumably the dataset carries its own targets;
                confirm.
            best_params: Parameters for ``self.best_model.set_params``.
                Defaults to ``self.model.get_params()``. A caller-supplied
                dict is copied and never modified.
            train_all: Only used when ``self.model.train_split`` is a
                ``CVSplit``.
                - If True, ``train_split`` is set to None, disabling the
                  internal validation split so all of ``dataset`` is used
                  for training.
                - Otherwise a fresh ``CVSplit(5)`` is used; presumably
                  one fold, roughly 20%, is held out for validation.

        Side effects:
            - Replaces the ``callbacks`` parameter with batch-level
              train/validation accuracy scorers.
            - Refits ``self.best_model``.
            - Writes the whole object to ``ARTIFACTS_DIR/<self.name>.pkl``.
        """
        if best_params is None:
            best_params = self.model.get_params()
        else:
            # Work on a copy so the updates below do not leak into the
            # caller's dict (e.g. a search result that is reused elsewhere).
            best_params = dict(best_params)

        # isinstance() is already False for None, so no separate None check.
        if isinstance(self.model.train_split, CVSplit):
            best_params["train_split"] = None if train_all else CVSplit(5)

        # Accuracy on training and validation batches. Higher is better.
        # These replace any callbacks present in ``best_params``.
        train_acc = BatchScoring(scoring='accuracy', on_train=True,
                                 name='train_acc', lower_is_better=False)
        valid_acc = BatchScoring(scoring='accuracy', on_train=False,
                                 name='valid_acc', lower_is_better=False)
        best_params["callbacks"] = [train_acc, valid_acc]

        self.best_model.set_params(**best_params)
        self.best_model.fit(dataset, None)

        # Persist this whole object, not just the fitted model.
        with open(path.join(ARTIFACTS_DIR, self.name + '.pkl'), 'wb') as f:
            pickle.dump(self, f)
예제 #4
0
    def _default_callbacks(self):
        """Return the default ``(name, callback)`` pairs for this net.

        In order:
            - ``epoch_timer``: an ``EpochTimer``.
            - ``train_loss``: a ``BatchScoring`` for the training loss
              (``on_train=True``).
            - ``valid_loss``: a ``BatchScoring`` for the validation loss.
            - ``valid_acc``: an ``EpochScoring`` of ``'accuracy'``, with
              ``lower_is_better=False``.
            - ``report``: a ``ReportLog``.
            - ``progressbar``: a ``ProgressBar``.

        The loss scorers use ``noop`` as ``target_extractor``. Presumably the
        targets are not needed because the loss was already recorded for the
        batch; confirm against the scoring functions.

        Checkpointing (``Checkpoint`` / ``TrainEndCheckpoint``), early
        stopping and learning-rate decay callbacks are not part of these
        defaults. Add them explicitly if a run needs them.
        """
        # Instantiate in the same order the list lists them.
        timer = EpochTimer()
        train_loss = BatchScoring(train_loss_score,
                                  name='train_loss',
                                  on_train=True,
                                  target_extractor=noop)
        valid_loss = BatchScoring(valid_loss_score,
                                  name='valid_loss',
                                  target_extractor=noop)
        valid_acc = EpochScoring('accuracy',
                                 name='valid_acc',
                                 lower_is_better=False)
        report = ReportLog()
        progress = ProgressBar()

        default_cb_list = [
            ('epoch_timer', timer),
            ('train_loss', train_loss),
            ('valid_loss', valid_loss),
            ('valid_acc', valid_acc),
            ('report', report),
            ('progressbar', progress),
        ]
        return default_cb_list
예제 #5
0
 def _default_callbacks(self):
     """Build the stock callback list: timer, train/valid loss, log printer.

     Returns:
         A list of ``(name, callback)`` tuples, in order:
         ``epoch_timer``, ``train_loss``, ``valid_loss``, ``print_log``.
         - The ``train_loss`` scorer runs on training batches
           (``on_train=True``).
         - Both loss scorers use ``noop`` as ``target_extractor``.
     """
     # Entries are appended one at a time, so construction order matches
     # the final list order.
     callbacks = [("epoch_timer", EpochTimer())]
     callbacks.append(
         (
             "train_loss",
             BatchScoring(
                 train_loss_score,
                 name="train_loss",
                 on_train=True,
                 target_extractor=noop,
             ),
         )
     )
     callbacks.append(
         (
             "valid_loss",
             BatchScoring(
                 valid_loss_score,
                 name="valid_loss",
                 target_extractor=noop,
             ),
         )
     )
     callbacks.append(("print_log", PrintLog()))
     return callbacks