コード例 #1
0
ファイル: train.py プロジェクト: afcarl/it3105
 def set_up_trainer(self):
     """Create ``self.trainer`` and register its hooks.

     Hyper-parameters and switches are read from ``self.args``
     (learning_rate, momentum, progress_bar, disable_saving,
     accuracy_threshold, patience, max_num_epochs). When saving is enabled,
     the best network is written to ``self.filename``. Saving and the
     threshold / early-stopping hooks all key on the 'validation.Accuracy'
     score produced by the MonitorScores hook registered here.
     """
     args = self.args
     hooks = bs.hooks
     monitored = 'validation.Accuracy'

     self.trainer = bs.Trainer(
         bs.training.MomentumStepper(learning_rate=args.learning_rate,
                                     momentum=args.momentum))
     add_hook = self.trainer.add_hook

     if args.progress_bar:
         add_hook(hooks.ProgressBar())

     # Scores the data passed to train() as ``valid_getter`` (presumably)
     # and logs the results under the name 'validation'.
     validation_scorers = [
         bs.scorers.Accuracy(out_name='Output.outputs.predictions')]
     add_hook(hooks.MonitorScores('valid_getter', validation_scorers,
                                  name='validation'))

     if not args.disable_saving:
         # Save the network when validation accuracy is at its best
         # (criterion='max').
         add_hook(hooks.SaveBestNetwork(monitored,
                                        filename=self.filename,
                                        name='best weights',
                                        criterion='max'))

     # Stopping criteria, all treating higher accuracy as better.
     add_hook(hooks.StopAfterThresholdReached(
         monitored,
         threshold=args.accuracy_threshold,
         criterion='max'))
     add_hook(hooks.EarlyStopper(monitored,
                                 patience=args.patience,
                                 criterion='max'))
     add_hook(hooks.StopAfterEpoch(args.max_num_epochs))
コード例 #2
0
def test_describe_trainer():
    """get_description() should capture the stepper, hooks and settings."""
    trainer = bs.Trainer(bs.training.SgdStepper(learning_rate=0.7),
                         verbose=False)
    trainer.add_hook(bs.hooks.StopAfterEpoch(23))
    trainer.add_hook(bs.hooks.StopOnNan())

    description = get_description(trainer)

    # Expected per-hook descriptions; the two hooks carry priorities 1 and 2.
    stop_after_epoch = {
        '@type': 'StopAfterEpoch',
        'max_epochs': 23,
        'priority': 1,
    }
    stop_on_nan = {
        '@type': 'StopOnNan',
        'check_parameters': True,
        'check_training_loss': True,
        'logs_to_check': [],
        'priority': 2,
    }
    expected = {
        '@type': 'Trainer',
        'verbose': False,
        'hooks': {
            'StopAfterEpoch': stop_after_epoch,
            'StopOnNan': stop_on_nan,
        },
        'stepper': {'@type': 'SgdStepper', 'learning_rate': 0.7},
        'train_scorers': [],
    }
    assert description == expected
コード例 #3
0
def create_trainer(training, net_filename, verbose):
    """Build a brainstorm Trainer configured from the ``training`` dict.

    Makes sure the directory that will hold ``net_filename`` exists, then
    creates an SGD trainer and registers these hooks:
      * NaN detection;
      * an epoch limit;
      * validation monitoring;
      * early stopping;
      * best-network saving;
      * an info updater.
    In verbose mode it also adds a SIGQUIT stopper and a progress bar.

    Args:
        training: mapping providing ``'learning_rate'``, ``'max_epochs'``
            and ``'patience'``.
        net_filename: path the best network (by validation total loss) is
            saved to. It may be a bare filename with no directory part.
        verbose: forwarded to ``bs.Trainer``. When true, the SIGQUIT and
            progress-bar hooks are added as well.

    Returns:
        The configured ``bs.Trainer``.
    """
    import os

    # os.path.dirname() returns '' for a bare filename such as 'net.hdf5',
    # and os.makedirs('') raises. Only create a directory when the path
    # actually has one and it does not exist yet.
    dirname = os.path.dirname(net_filename)
    if dirname and not os.path.exists(dirname):
        os.makedirs(dirname)

    trainer = bs.Trainer(bs.training.SgdStepper(training['learning_rate']),
                         verbose=verbose)
    # The same list object is handed to MonitorScores below, so validation
    # is scored with the same scorers as training.
    trainer.train_scorers = [bs.scorers.Hamming()]
    trainer.add_hook(bs.hooks.StopOnNan())
    trainer.add_hook(bs.hooks.StopAfterEpoch(training['max_epochs']))
    trainer.add_hook(
        bs.hooks.MonitorScores('val_iter',
                               trainer.train_scorers,
                               name='validation'))
    # total_loss is lower-is-better. State the direction explicitly (as the
    # SaveBestNetwork hook below does) instead of relying on EarlyStopper's
    # default criterion.
    trainer.add_hook(
        bs.hooks.EarlyStopper('validation.total_loss',
                              patience=training['patience'],
                              criterion='min'))
    trainer.add_hook(
        bs.hooks.SaveBestNetwork('validation.total_loss',
                                 filename=net_filename,
                                 criterion='min'))
    # NOTE(review): `ex` is not defined in this function; presumably a
    # module-level experiment object defined elsewhere in the file -- confirm.
    trainer.add_hook(bs.hooks.InfoUpdater(ex))
    if verbose:
        trainer.add_hook(bs.hooks.StopOnSigQuit())
        trainer.add_hook(bs.hooks.ProgressBar())
    return trainer
コード例 #4
0
def test_recreate_trainer_from_description():
    """A Trainer rebuilt from its description should mirror the original."""
    original = bs.Trainer(bs.training.SgdStepper(learning_rate=0.7),
                          verbose=False)
    original.add_hook(bs.hooks.StopAfterEpoch(23))
    original.add_hook(bs.hooks.StopOnNan())

    # Round-trip: describe the trainer, then build a new one from that.
    rebuilt = create_from_description(get_description(original))

    assert isinstance(rebuilt, bs.Trainer)
    assert rebuilt.verbose is False
    # Hook names must come back in registration order.
    assert list(rebuilt.hooks) == ['StopAfterEpoch', 'StopOnNan']
    assert rebuilt.hooks['StopAfterEpoch'].max_epochs == 23

    stepper = rebuilt.stepper
    assert isinstance(stepper, bs.training.SgdStepper)
    assert stepper.learning_rate == 0.7
コード例 #5
0
# Validation split, read fully into memory. `ds`, `x_tr` and `y_tr` are
# expected to be defined earlier in the script (not visible here).
x_va, y_va = ds['validation']['default'][:], ds['validation']['targets'][:]

# Mini-batch iterators; the leading 100 is presumably the batch size.
getter_tr = Minibatches(100, default=x_tr, targets=y_tr)
getter_va = Minibatches(100, default=x_va, targets=y_va)

# ----------------------------- Set up Network ------------------------------ #

# Classification network for inputs of shape (28, 28, 1) with 10 outputs.
# The hidden part is a linear 500-unit FullyConnected layer followed by the
# Square layer named 'MySquareLayer'.
inp, out = bs.tools.get_in_out_layers('classification', (28, 28, 1), 10)
network = bs.Network.from_layer(
    inp
    >> FullyConnected(500, name='Hid1', activation='linear')
    >> Square(name='MySquareLayer')
    >> out
)

network.initialize(bs.initializers.Gaussian(0.01))

# ----------------------------- Set up Trainer ------------------------------ #

trainer = bs.Trainer(
    bs.training.MomentumStepper(learning_rate=0.01, momentum=0.9))
trainer.add_hook(bs.hooks.ProgressBar())
scorers = [bs.scorers.Accuracy(out_name='Output.outputs.probabilities')]
trainer.add_hook(
    bs.hooks.MonitorScores('valid_getter', scorers, name='validation'))
# Accuracy is higher-is-better. State the direction explicitly rather than
# relying on EarlyStopper's default criterion.
trainer.add_hook(bs.hooks.EarlyStopper('validation.Accuracy', patience=10,
                                       criterion='max'))
trainer.add_hook(bs.hooks.StopAfterEpoch(500))

# -------------------------------- Train ------------------------------------ #

trainer.train(network, getter_tr, valid_getter=getter_va)
print("Best validation set accuracy:",
      max(trainer.logs["validation"]["Accuracy"]))
コード例 #6
0
                                        projection_name='FC')
# Two 1200-unit 'rel' hidden layers, each wrapped in dropout (0.2 on the
# input, 0.5 after each hidden layer), feeding the output layer `fc`.
network = brnst.Network.from_layer(
    inp
    >> brnst.layers.Dropout(drop_prob=0.2)
    >> brnst.layers.FullyConnected(1200, name='Hid1', activation='rel')
    >> brnst.layers.Dropout(drop_prob=0.5)
    >> brnst.layers.FullyConnected(1200, name='Hid2', activation='rel')
    >> brnst.layers.Dropout(drop_prob=0.5)
    >> fc
)

# Uncomment next line to use GPU
# network.set_handler(PyCudaHandler())
network.initialize(brnst.initializers.Gaussian(0.05))
# Apply ConstrainL2Norm(1) to the weights of 'FC' (the projection_name given
# to the output layers above).
network.set_weight_modifiers(
    {"FC": brnst.value_modifiers.ConstrainL2Norm(1)}
)

# ----------------------------- Set up Trainer ------------------------------ #

trainer = brnst.Trainer(
    brnst.training.MomentumStepper(learning_rate=0.1, momentum=0.9)
)
trainer.add_hook(brnst.hooks.ProgressBar())
# Optional live plot of validation accuracy:
# trainer.add_hook(brnst.hooks.BokehVisualizer('validation.Accuracy'))
scorers = [brnst.scorers.Accuracy(out_name='Output.outputs.predictions')]
trainer.add_hook(brnst.hooks.MonitorScores('valid_getter', scorers,
                                           name='validation'))
# NOTE(review): the filename says "500" but training stops after 10 epochs
# below; confirm which is intended.
trainer.add_hook(brnst.hooks.SaveBestNetwork(
    'validation.Accuracy',
    filename='mnist_pi_best500.hdf5',
    name='best weights',
    criterion='max',
))
trainer.add_hook(brnst.hooks.StopAfterEpoch(10))

# -------------------------------- Train ------------------------------------ #

trainer.train(network, getter_tr, valid_getter=getter_va)