示例#1
0
 def get_default_callbacks(self):
     """Build the default ``(name, callback)`` pairs.

     The returned list contains, in order: an epoch timer, train-loss and
     valid-loss scoring, validation accuracy, a ``PrintLog`` whose sink is
     ``logger_info``, and a ``roc_auc`` scoring callback.
     """
     timer = EpochTimer()
     train_loss = EpochScoring(
         train_loss_score, name='train_loss', on_train=True)
     valid_loss = EpochScoring(valid_loss_score, name='valid_loss')
     valid_acc = EpochScoring(
         'accuracy', name='valid_acc', lower_is_better=False)
     # 'batches' is handed to PrintLog as a key to ignore.
     log_printer = PrintLog(sink=logger_info, keys_ignored=['batches'])
     auc = EpochScoring(scoring='roc_auc', lower_is_better=False)

     return [
         ('epoch_timer', timer),
         ('train_loss', train_loss),
         ('valid_loss', valid_loss),
         ('valid_acc', valid_acc),
         ('print_log', log_printer),
         ('auc', auc),
     ]
示例#2
0
文件: gan.py 项目: bdura/bestiary
 def _default_callbacks(self):
     """Assemble the default ``(name, callback)`` list.

     An epoch timer comes first, followed by one ``PassthroughScoring``
     per metric (all created with ``lower_is_better=True``), and a
     ``PrintLog`` at the end.
     """
     callbacks = [('epoch_timer', EpochTimer())]

     # (metric name, extra keyword arguments); only the first two metrics
     # are created with on_train=True.
     metric_specs = (
         ('critic_loss', {'on_train': True}),
         ('train_distance', {'on_train': True}),
         ('valid_distance', {}),
         ('inter_distance', {}),
     )
     for metric, extra in metric_specs:
         scorer = PassthroughScoring(
             name=metric, lower_is_better=True, **extra)
         callbacks.append((metric, scorer))

     callbacks.append(('print_log', PrintLog()))
     return callbacks
示例#3
0
 def _default_callbacks(self):
     """Return the default ``(name, callback)`` pairs.

     Contains an epoch timer, passthrough scoring for ``train_loss`` and
     ``valid_loss``, accuracy scoring for both the train and the valid
     data, and a ``PrintLog``.
     """
     timer = EpochTimer()
     train_loss = PassthroughScoring(name='train_loss', on_train=True)
     valid_loss = PassthroughScoring(name='valid_loss')
     # add train accuracy because by default, there is no valid split
     train_acc = EpochScoring(
         'accuracy', name='train_acc', lower_is_better=False, on_train=True)
     valid_acc = EpochScoring(
         'accuracy', name='valid_acc', lower_is_better=False)
     log_printer = PrintLog()

     return [
         ('epoch_timer', timer),
         ('train_loss', train_loss),
         ('valid_loss', valid_loss),
         ('train_acc', train_acc),
         ('valid_acc', valid_acc),
         ('print_log', log_printer),
     ]
示例#4
0
 def _default_callbacks(self):
     """Assemble the default ``(name, callback)`` list.

     An epoch timer comes first, then one ``PassthroughScoring`` with
     ``on_train=True`` for each loss name, and finally a ``PrintLog``.
     """
     loss_names = (
         'train_loss',
         'discriminator_loss',
         'generator_loss',
         'adversarial_loss',
         'contextual_loss',
         'encoder_loss',
     )

     callbacks = [('epoch_timer', EpochTimer())]
     for loss_name in loss_names:
         # Each loss is registered under the same name it is scored as.
         scorer = PassthroughScoring(name=loss_name, on_train=True)
         callbacks.append((loss_name, scorer))
     callbacks.append(('print_log', PrintLog()))
     return callbacks
示例#5
0
    def test_setting_callback_default_possible(self, net_cls, module_cls):
        """The default ``print_log`` callback can be swapped via ``set_params``."""
        from skorch.callbacks import EpochTimer, PrintLog

        model = net_cls(module_cls)
        model.initialize()

        # Before the change the entry named 'print_log' is a PrintLog.
        before = dict(model.callbacks_)['print_log']
        assert isinstance(before, PrintLog)

        model.set_params(callbacks__print_log=EpochTimer())

        # After set_params the same name maps to the replacement callback.
        after = dict(model.callbacks_)['print_log']
        assert isinstance(after, EpochTimer)
示例#6
0
    def test_setting_callback_possible(self, net_cls, module_cls):
        """A user-named callback can be swapped via ``set_params``."""
        from skorch.callbacks import EpochTimer, PrintLog

        user_callbacks = [('mycb', PrintLog())]
        model = net_cls(module_cls, callbacks=user_callbacks)
        model.initialize()

        # Before the change the entry named 'mycb' is the PrintLog passed in.
        before = dict(model.callbacks_)['mycb']
        assert isinstance(before, PrintLog)

        model.set_params(callbacks__mycb=EpochTimer())

        # After set_params the same name maps to the replacement callback.
        after = dict(model.callbacks_)['mycb']
        assert isinstance(after, EpochTimer)
示例#7
0
 def get_default_callbacks(self):
     """Return the default ``(name, callback)`` pairs.

     Contains an epoch timer, ``Scoring`` callbacks for train loss, valid
     loss and validation accuracy, and a ``PrintLog``.
     """
     timer = EpochTimer()
     train_loss = Scoring('train_loss', train_loss_score, on_train=True)
     valid_loss = Scoring('valid_loss', valid_loss_score)
     # Accuracy is created with on_train=False and gets its predictions
     # through accuracy_pred_extractor.
     valid_acc = Scoring(
         name='valid_acc',
         scoring='accuracy_score',
         lower_is_better=False,
         on_train=False,
         pred_extractor=accuracy_pred_extractor,
     )
     log_printer = PrintLog()

     return [
         ('epoch_timer', timer),
         ('train_loss', train_loss),
         ('valid_loss', valid_loss),
         ('valid_acc', valid_acc),
         ('print_log', log_printer),
     ]
示例#8
0
 def _default_callbacks(self):
     """Return the default ``(name, callback)`` pairs.

     Contains an epoch timer, batch-level scoring of train and valid loss
     (both created with ``target_extractor=noop``), and a ``PrintLog``.
     """
     timer = EpochTimer()
     train_loss = BatchScoring(
         train_loss_score,
         name='train_loss',
         on_train=True,
         target_extractor=noop,
     )
     valid_loss = BatchScoring(
         valid_loss_score,
         name='valid_loss',
         target_extractor=noop,
     )
     log_printer = PrintLog()

     return [
         ('epoch_timer', timer),
         ('train_loss', train_loss),
         ('valid_loss', valid_loss),
         ('print_log', log_printer),
     ]
示例#9
0
def callbacks_monkey_patch(self):
    """Return ``(name, callback)`` pairs whose loss scoring encodes targets.

    Both loss callbacks receive a target extractor that reshapes the
    targets into a single column and passes them through
    ``self.encoder.transform``.
    """
    def encode_targets(x):
        # Reshape to (-1, 1) before handing the targets to the encoder.
        column = np.array(x).reshape(-1, 1)
        return self.encoder.transform(column)

    timer = EpochTimer()
    train_loss = BatchScoring(
        train_loss_score,
        name='train_loss',
        on_train=True,
        target_extractor=encode_targets,
    )
    valid_loss = BatchScoring(
        valid_loss_score,
        name='valid_loss',
        target_extractor=encode_targets,
    )
    log_printer = PrintLog()

    return [
        ('epoch_timer', timer),
        ('train_loss', train_loss),
        ('valid_loss', valid_loss),
        ('print_log', log_printer),
    ]
示例#10
0
文件: net.py 项目: rain1024/skorch
 def get_default_callbacks(self):
     """Assemble the default ``(name, callback)`` list.

     The list is built in order: epoch timer, train-loss scoring,
     valid-loss scoring, validation accuracy, and a ``PrintLog``.
     """
     callbacks = []
     callbacks.append(('epoch_timer', EpochTimer()))

     train_loss = EpochScoring(
         train_loss_score, name='train_loss', on_train=True)
     callbacks.append(('train_loss', train_loss))

     valid_loss = EpochScoring(valid_loss_score, name='valid_loss')
     callbacks.append(('valid_loss', valid_loss))

     valid_acc = EpochScoring(
         'accuracy', name='valid_acc', lower_is_better=False)
     callbacks.append(('valid_acc', valid_acc))

     callbacks.append(('print_log', PrintLog()))
     return callbacks
示例#11
0
    def _default_callbacks(self):
        """Return the default ``(name, callback)`` pairs.

        Contains an epoch timer, batch-level train/valid loss scoring,
        epoch-level validation accuracy, a ``ReportLog`` and a
        ``ProgressBar``. Checkpointing, early stopping and learning-rate
        scheduling are present only as commented-out code below.
        """
        timer = EpochTimer()
        train_loss = BatchScoring(
            train_loss_score,
            name='train_loss',
            on_train=True,
            target_extractor=noop,
        )
        valid_loss = BatchScoring(
            valid_loss_score,
            name='valid_loss',
            target_extractor=noop,
        )
        valid_acc = EpochScoring(
            'accuracy', name='valid_acc', lower_is_better=False)
        # Disabled checkpoint callbacks:
        # ('checkpoint', Checkpoint(
        #     dirname=self.model_path)),
        # ('end_checkpoint', TrainEndCheckpoint(
        #     dirname=self.model_path)),
        report = ReportLog()
        progress = ProgressBar()

        callbacks = [
            ('epoch_timer', timer),
            ('train_loss', train_loss),
            ('valid_loss', valid_loss),
            ('valid_acc', valid_acc),
            ('report', report),
            ('progressbar', progress),
        ]

        # Disabled optional callbacks:
        # if 'stop_patience' in self.hyperparamters.keys() and \
        #         self.hyperparamters['stop_patience']:
        #     earlystop_cb = ('earlystop',  EarlyStopping(
        #                     patience=self.patience,
        #                     threshold=1e-4))
        #     default_cb_list.append(earlystop_cb)
        #
        # if 'lr_step' in self.hyperparamters.keys() and \
        #         self.hyperparamters['lr_step']:
        #     lr_callback = ('lr_schedule', DecayLR(
        #                    self.hyperparamters['lr'],
        #                    self.hyperparamters['lr_step'],
        #                    gamma=0.5))
        #     default_cb_list.append(lr_callback)

        return callbacks
示例#12
0
 def _default_callbacks(self):
     """Assemble the default ``(name, callback)`` list.

     The list is built in order: epoch timer, batch-level train-loss and
     valid-loss scoring (both created with ``target_extractor=noop``),
     and a ``PrintLog``.
     """
     callbacks = [("epoch_timer", EpochTimer())]

     train_loss = BatchScoring(
         train_loss_score,
         name="train_loss",
         on_train=True,
         target_extractor=noop,
     )
     callbacks.append(("train_loss", train_loss))

     valid_loss = BatchScoring(
         valid_loss_score,
         name="valid_loss",
         target_extractor=noop,
     )
     callbacks.append(("valid_loss", valid_loss))

     callbacks.append(("print_log", PrintLog()))
     return callbacks