Beispiel #1
0
    def test_print_skipped_if_not_verbose(self, capsys):
        from skorch.callbacks import PrintLog

        # A net with verbose=0 must produce no output at epoch end.
        callback = PrintLog().initialize()
        mock_net = Mock(history=[{'loss': 123}], verbose=0)

        callback.on_epoch_end(mock_net)

        captured_out, _ = capsys.readouterr()
        assert not captured_out
Beispiel #2
0
    def test_print_skipped_if_not_verbose(self, capsys):
        from skorch.callbacks import PrintLog

        # verbose=0 suppresses the epoch table entirely.
        log_cb = PrintLog().initialize()
        fake_net = Mock(history=[{'loss': 123}], verbose=0)
        log_cb.on_epoch_end(fake_net)

        assert not capsys.readouterr()[0]
Beispiel #3
0
    def test_args_passed_to_tabulate(self, history):
        # Custom formatting options must be forwarded to tabulate.
        with patch('skorch.callbacks.tabulate') as tab:
            from skorch.callbacks import PrintLog
            log_cb = PrintLog(tablefmt='latex', floatfmt='.9f').initialize()
            log_cb.table(history[-1])

            assert tab.call_count == 1
            kwargs = tab.call_args_list[0][1]
            assert kwargs['tablefmt'] == 'latex'
            assert kwargs['floatfmt'] == '.9f'
Beispiel #4
0
    def test_print_not_skipped_if_verbose(self, capsys):
        from skorch.callbacks import PrintLog

        # With verbose=1 the header, separator, and value are printed.
        callback = PrintLog().initialize()
        mock_net = Mock(history=[{'loss': 123}], verbose=1)
        callback.on_epoch_end(mock_net)

        captured_out, _ = capsys.readouterr()
        tokens = [token.strip() for token in captured_out.split()]
        assert tokens == ['loss', '------', '123']
Beispiel #5
0
    def test_args_passed_to_tabulate(self, history):
        # The formatting kwargs given to PrintLog must reach tabulate.
        with patch('skorch.callbacks.logging.tabulate') as tab:
            from skorch.callbacks import PrintLog
            print_log = PrintLog(tablefmt='latex', floatfmt='.9f')
            print_log.initialize()
            print_log.table(history[-1])

            assert tab.call_count == 1
            _, passed_kwargs = tab.call_args_list[0]
            assert passed_kwargs['tablefmt'] == 'latex'
            assert passed_kwargs['floatfmt'] == '.9f'
Beispiel #6
0
    def test_print_not_skipped_if_verbose(self, capsys):
        from skorch.callbacks import PrintLog

        # verbose=1 -> the loss column is printed as a small table.
        log_cb = PrintLog().initialize()
        log_cb.on_epoch_end(Mock(history=[{'loss': 123}], verbose=1))

        printed = capsys.readouterr()[0]
        assert [piece.strip() for piece in printed.split()] == [
            'loss', '------', '123']
Beispiel #7
0
 def _default_callbacks(self):
     """Return the default callbacks: an epoch timer, one passthrough
     scorer per tracked training loss, and a print log."""
     loss_names = [
         'train_loss',
         'discriminator_loss',
         'generator_loss',
         'adversarial_loss',
         'contextual_loss',
         'encoder_loss',
     ]
     # Every loss is recorded on the training data.
     scorers = [
         (name, PassthroughScoring(name=name, on_train=True))
         for name in loss_names
     ]
     return (
         [('epoch_timer', EpochTimer())]
         + scorers
         + [('print_log', PrintLog())]
     )
Beispiel #8
0
 def _default_callbacks(self):
     """Default callbacks: timer, train/valid loss passthroughs,
     train/valid accuracy scoring, and a print log."""
     callbacks = [('epoch_timer', EpochTimer())]
     callbacks.append(('train_loss', PassthroughScoring(
         name='train_loss',
         on_train=True,
     )))
     callbacks.append(('valid_loss', PassthroughScoring(name='valid_loss')))
     # Track train accuracy because by default there is no valid split.
     callbacks.append(('train_acc', EpochScoring(
         'accuracy',
         name='train_acc',
         lower_is_better=False,
         on_train=True,
     )))
     callbacks.append(('valid_acc', EpochScoring(
         'accuracy',
         name='valid_acc',
         lower_is_better=False,
     )))
     callbacks.append(('print_log', PrintLog()))
     return callbacks
Beispiel #9
0
 def _default_callbacks(self):
     """Default callbacks: timer, lower-is-better distance/loss scorers
     (two on train data, two on validation data), and a print log."""
     train_metrics = ('critic_loss', 'train_distance')
     eval_metrics = ('valid_distance', 'inter_distance')

     entries = [('epoch_timer', EpochTimer())]
     for metric in train_metrics:
         entries.append((metric, PassthroughScoring(
             name=metric,
             lower_is_better=True,
             on_train=True,
         )))
     for metric in eval_metrics:
         entries.append((metric, PassthroughScoring(
             name=metric,
             lower_is_better=True,
         )))
     entries.append(('print_log', PrintLog()))
     return entries
Beispiel #10
0
    def test_setting_callback_possible(self, net_cls, module_cls):
        from skorch.callbacks import EpochTimer, PrintLog

        # A named callback can be swapped out later via set_params.
        net = net_cls(module_cls, callbacks=[('mycb', PrintLog())])
        net.initialize()
        installed = dict(net.callbacks_)['mycb']
        assert isinstance(installed, PrintLog)

        net.set_params(callbacks__mycb=EpochTimer())
        replaced = dict(net.callbacks_)['mycb']
        assert isinstance(replaced, EpochTimer)
Beispiel #11
0
 def get_default_callbacks(self):
     """Return the default callback list: timer, loss scorers,
     validation accuracy, and a print log."""
     train_scoring = Scoring('train_loss', train_loss_score, on_train=True)
     valid_scoring = Scoring('valid_loss', valid_loss_score)
     accuracy_scoring = Scoring(
         name='valid_acc',
         scoring='accuracy_score',
         lower_is_better=False,
         on_train=False,
         pred_extractor=accuracy_pred_extractor,
     )
     return [
         ('epoch_timer', EpochTimer()),
         ('train_loss', train_scoring),
         ('valid_loss', valid_scoring),
         ('valid_acc', accuracy_scoring),
         ('print_log', PrintLog()),
     ]
Beispiel #12
0
 def _default_callbacks(self):
     """Default callbacks: timer, batch-level train/valid loss scoring
     with passthrough targets, and a print log."""
     train_scoring = BatchScoring(
         train_loss_score,
         name='train_loss',
         on_train=True,
         target_extractor=noop,
     )
     valid_scoring = BatchScoring(
         valid_loss_score,
         name='valid_loss',
         target_extractor=noop,
     )
     return [
         ('epoch_timer', EpochTimer()),
         ('train_loss', train_scoring),
         ('valid_loss', valid_scoring),
         ('print_log', PrintLog()),
     ]
Beispiel #13
0
def callbacks_monkey_patch(self):
    """Patched callback list whose loss scorers run targets through
    ``self.encoder`` before scoring."""
    # Targets arrive as a flat sequence; reshape to a single column and
    # transform them with the instance's encoder.
    def _encode_targets(x):
        return self.encoder.transform(np.array(x).reshape(-1, 1))

    return [
        ('epoch_timer', EpochTimer()),
        ('train_loss', BatchScoring(
            train_loss_score,
            name='train_loss',
            on_train=True,
            target_extractor=_encode_targets,
        )),
        ('valid_loss', BatchScoring(
            valid_loss_score,
            name='valid_loss',
            target_extractor=_encode_targets,
        )),
        ('print_log', PrintLog()),
    ]
Beispiel #14
0
def train():
    """Tune a VGG classifier's learning rate with a dask-parallel
    randomized search and pickle the fitted search to ``result.pkl``.

    Builds a skorch ``NeuralNetClassifier`` around ``VGGNet`` with custom
    callbacks, then runs a 3-fold ``RandomizedSearchCV`` over ``lr`` with
    the CV fits fanned out to dask workers via joblib.
    """

    disc = VGGNet()

    # NOTE(review): SaveBestParam / StopRestore / Score_ConfusionMatrix
    # appear to be project-local callbacks (checkpointing, early stopping,
    # accuracy + confusion-matrix scoring) — confirm against their
    # definitions before relying on these descriptions.
    cp = SaveBestParam(dirname='best')
    early_stop = StopRestore(patience=10)
    score = Score_ConfusionMatrix(scoring="accuracy", lower_is_better=False)
    # Keep the non-scalar confusion matrix out of the printed epoch table.
    pt = PrintLog(keys_ignored="confusion_matrix")
    net = NeuralNetClassifier(disc,
                              max_epochs=100,
                              lr=0.01,
                              device='cuda',
                              callbacks=[('best', cp), ('early', early_stop)],
                              iterator_train__shuffle=True,
                              iterator_valid__shuffle=False)
    # Replace the default valid_acc scorer and print log with the custom ones.
    net.set_params(callbacks__valid_acc=score)
    net.set_params(callbacks__print_log=pt)

    # X, y = load_data()
    # net.fit(X, y)
    # print(1)

    # Only the learning rate is searched.
    param_dist = {
        'lr': [0.05, 0.01, 0.005],
    }

    search = RandomizedSearchCV(net,
                                param_dist,
                                cv=StratifiedKFold(n_splits=3),
                                n_iter=3,
                                verbose=10,
                                scoring='accuracy')

    X, y = load_data()

    # search.fit(X, y)

    # Connect to an already-running dask scheduler so joblib can dispatch
    # the CV fits to its workers.
    Client("127.0.0.1:8786")  # create local cluster

    with joblib.parallel_backend('dask'):
        search.fit(X, y)

    with open('result.pkl', 'wb') as f:
        pickle.dump(search, f)
Beispiel #15
0
 def get_default_callbacks(self):
     """Default callbacks: timer, epoch-level train/valid loss scorers,
     validation accuracy, and a print log."""
     epoch_scorers = [
         ('train_loss', EpochScoring(
             train_loss_score,
             name='train_loss',
             on_train=True,
         )),
         ('valid_loss', EpochScoring(
             valid_loss_score,
             name='valid_loss',
         )),
         ('valid_acc', EpochScoring(
             'accuracy',
             name='valid_acc',
             lower_is_better=False,
         )),
     ]
     return [('epoch_timer', EpochTimer())] + epoch_scorers + [
         ('print_log', PrintLog()),
     ]
    def hyperparameter_tunning(self):
        """Grid-search learning rate and embedding size for SiameseNetV2
        with 3-fold CV scored by F1."""
        estimator = NeuralNet(
            network.SiameseNetV2,
            max_epochs=2,
            batch_size=128,
            criterion=BCELoss,
            optimizer=Adam,
            iterator_train__num_workers=4,
            iterator_train__pin_memory=False,
            iterator_valid__num_workers=4,
            verbose=2,
            device='cuda',
            iterator_train__shuffle=True,
            callbacks=[PrintLog(), ProgressBar()],
        )
        # Grid search supplies its own CV split, so disable skorch's
        # internal train/valid split.
        estimator.set_params(train_split=False)

        search_space = {'lr': [0.01, 0.001], 'module__num_dims': [128, 256]}
        searcher = GridSearchCV(
            estimator, search_space, refit=False, cv=3, scoring='f1')
        # SliceDataset exposes dataset columns as sklearn-compatible inputs.
        inputs = SliceDataset(self.train_set, idx=0)
        targets = SliceDataset(self.train_set, idx=1)
        searcher.fit(inputs, targets)
Beispiel #17
0
 def _default_callbacks(self):
     """Default callbacks: timer, batch-level train/valid loss scoring
     with passthrough target extraction, and a print log."""
     train_scoring = BatchScoring(
         train_loss_score,
         name="train_loss",
         on_train=True,
         target_extractor=noop,
     )
     valid_scoring = BatchScoring(
         valid_loss_score,
         name="valid_loss",
         target_extractor=noop,
     )
     return [
         ("epoch_timer", EpochTimer()),
         ("train_loss", train_scoring),
         ("valid_loss", valid_scoring),
         ("print_log", PrintLog()),
     ]
Beispiel #18
0
def train():
    """Train a VGG classifier on CUDA with checkpointing, early stopping,
    and NaN/Inf batch guarding, then fit it on the loaded data."""

    disc = VGGNet()

    # NOTE(review): SaveBestParam / StopRestore / CheckNanInf /
    # Score_ConfusionMatrix appear to be project-local callbacks
    # (checkpointing, early stop + restore, NaN/Inf guarding, accuracy +
    # confusion-matrix scoring) — confirm against their definitions.
    cp = SaveBestParam(dirname='best')
    early_stop = StopRestore(patience=10000)
    nan_batch = CheckNanInf('accuracy', lower_is_better=False, on_train=True)
    score = Score_ConfusionMatrix(scoring="accuracy", lower_is_better=False)
    # Keep the non-scalar confusion matrix out of the printed epoch table.
    pt = PrintLog(keys_ignored="confusion_matrix")
    net = NeuralNetClassifier(disc,
                              max_epochs=10000,
                              lr=0.01,
                              device='cuda',
                              callbacks=[('best', cp), ('early', early_stop),
                                         ('nan', nan_batch)],
                              iterator_train__shuffle=True,
                              iterator_valid__shuffle=False)
    # Replace the default valid_acc scorer and print log with the custom ones.
    net.set_params(callbacks__valid_acc=score)
    net.set_params(callbacks__print_log=pt)

    X, y = load_data()
    net.fit(X, y)
    print(1)