def test_interval_is_1(self, mock_save_check):
        """Check that ``period=1`` saves on every ``on_checkpoint`` call.

        ``mock_save_check`` is presumably a patched save routine injected by a
        decorator that is not visible here; only its call count is inspected.
        """
        state = {}
        check = Interval('test_file', period=1)

        check.on_checkpoint(state)
        check.on_checkpoint(state)

        # assertEqual reports the observed count on failure, whereas
        # assertTrue(count == 2) would only say "False is not true".
        self.assertEqual(mock_save_check.call_count, 2)
Exemplo n.º 2
0
    def test_interval_is_1(self, mock_save_check):
        """Check that ``period=1`` saves on every ``on_end_epoch`` call.

        ``mock_save_check`` is presumably a patched save routine injected by a
        decorator that is not visible here; only its call count is inspected.
        """
        state = {}
        check = Interval('test_file', period=1)

        check.on_end_epoch(state)
        check.on_end_epoch(state)

        # assertEqual reports the observed count on failure, whereas
        # assertTrue(count == 2) would only say "False is not true".
        self.assertEqual(mock_save_check.call_count, 2)
    def test_interval_is_more_than_1(self, mock_save_check):
        """Check that ``period=4`` saves once per four ``on_checkpoint`` calls.

        Thirteen calls are made. The save count is probed after the 4th, 7th
        and 8th call, and must be 3 once the loop has finished.
        ``mock_save_check`` is presumably a patched save routine injected by a
        decorator that is not visible here.
        """
        state = {}
        check = Interval('test_file', period=4)

        for i in range(13):
            check.on_checkpoint(state)
            # Dedicated equality assertions are used so that a failure shows
            # the actual call count instead of a bare True/False mismatch.
            if i == 3:
                self.assertEqual(mock_save_check.call_count, 1)
            elif i == 6:
                # One call short of the second period: no second save yet.
                self.assertNotEqual(mock_save_check.call_count, 2)
            elif i == 7:
                self.assertEqual(mock_save_check.call_count, 2)

        self.assertEqual(mock_save_check.call_count, 3)
Exemplo n.º 4
0
    def test_interval_is_more_than_1(self, mock_save_check):
        """Check that ``period=4`` saves once per four ``on_end_epoch`` calls.

        Thirteen calls are made. The save count is probed after the 4th, 7th
        and 8th call, and must be 3 once the loop has finished.
        ``mock_save_check`` is presumably a patched save routine injected by a
        decorator that is not visible here.
        """
        state = {}
        check = Interval('test_file', period=4)

        for i in range(13):
            check.on_end_epoch(state)
            # Dedicated equality assertions are used so that a failure shows
            # the actual call count instead of a bare True/False mismatch.
            if i == 3:
                self.assertEqual(mock_save_check.call_count, 1)
            elif i == 6:
                # One call short of the second period: no second save yet.
                self.assertNotEqual(mock_save_check.call_count, 2)
            elif i == 7:
                self.assertEqual(mock_save_check.call_count, 2)

        self.assertEqual(mock_save_check.call_count, 3)
Exemplo n.º 5
0
    def test_state_dict(self):
        """Round-trip ``most_recent`` and ``epochs_since_last_save``.

        The two attributes are set on one ``Interval``, exported with
        ``state_dict`` and loaded into a fresh ``Interval``; the fresh instance
        must then expose the same values.
        """
        source = Interval('test')
        source.most_recent = 'temp'
        source.epochs_since_last_save = 10
        snapshot = source.state_dict()

        restored = Interval('test')
        restored.load_state_dict(snapshot)

        self.assertEqual(restored.most_recent, 'temp')
        self.assertEqual(restored.epochs_since_last_save, 10)
Exemplo n.º 6
0
def train(args,
          model,
          model_loss,
          trainloader,
          valloader,
          epochs,
          name='model'):
    """Build a ``tb.Trial`` for ``model``, run it and return it.

    Args:
        args: Parsed options; this function reads ``learning_rate``,
            ``output`` and ``device`` and forwards ``args`` to
            ``build_optimiser``.
        model: Model handed to the optimiser builder and the trial.
        model_loss: Criterion for the trial. When it is an
            ``nn.CrossEntropyLoss`` the ``'acc'`` metric is tracked as well.
        trainloader: Generator used for training.
        valloader: Generator used for validation.
        epochs: Number of epochs passed to ``trial.run``.
        name: Prefix for the files written under ``args.output``.

    Returns:
        The trial after ``run`` has completed.
    """
    init_lr, sched = parse_learning_rate_arg(args.learning_rate)
    # Every output file shares the "<output>/<name>" stem.
    stem = str(args.output) + '/' + name

    optimiser = build_optimiser(args, model, init_lr)

    callbacks = [
        Interval(filepath=stem + '.{epoch:02d}.pt', period=10),
        MostRecent(filepath=stem + '_final.pt'),
        CSVLogger(stem + '-train-log.csv'),
    ]
    # Whatever parse_learning_rate_arg returned as callbacks goes last.
    callbacks.extend(sched)

    track_accuracy = isinstance(model_loss, nn.CrossEntropyLoss)
    metrics = ['loss', 'lr'] + (['acc'] if track_accuracy else [])

    trial = tb.Trial(model,
                     optimiser,
                     model_loss,
                     metrics=metrics,
                     callbacks=callbacks)
    trial = trial.to(args.device)
    trial.with_generators(train_generator=trainloader, val_generator=valloader)
    trial.run(epochs=epochs, verbose=2)
    return trial
Exemplo n.º 7
0
def main():
    """Parse the command line, train a model, then run it on the test data.

    Arguments are parsed in two passes: a ``FakeArgumentParser`` first reads
    the shared arguments with ``parse_known_args``, and the result is handed
    to ``add_sub_args`` together with the real parser (presumably so that
    further arguments can be registered based on the shared ones).

    Training writes periodic and most-recent checkpoints plus a CSV log under
    ``args.output``; the test pass writes its own CSV log there.
    """
    # Pass 1: shared arguments only, tolerating unknown options.
    probe_parser = FakeArgumentParser(add_help=False, allow_abbrev=False)
    add_shared_args(probe_parser)
    probe_args, _unknown = probe_parser.parse_known_args()

    # Pass 2: the real parser with shared and sub arguments.
    parser = argparse.ArgumentParser()
    add_shared_args(parser)
    add_sub_args(probe_args, parser)
    args = parser.parse_args()

    trainloader, valloader, testloader = build_dataloaders(args)

    args.output.mkdir(exist_ok=True, parents=True)
    out_dir = str(args.output)
    save_args(args.output)

    model_factory = get_model(args.model)
    model = model_factory()

    init_lr, sched = parse_learning_rate_arg(args.learning_rate)

    # Anything other than 'Adam' falls back to SGD with momentum.
    if args.optimiser != 'Adam':
        optimiser = optim.SGD(model.parameters(),
                              lr=init_lr,
                              weight_decay=args.weight_decay,
                              momentum=args.momentum)
    else:
        optimiser = optim.Adam(model.parameters(),
                               lr=init_lr,
                               weight_decay=args.weight_decay)

    train_callbacks = [
        Interval(filepath=out_dir + '/model.{epoch:02d}.pt', period=10),
        MostRecent(filepath=out_dir + '/model_final.pt'),
        CSVLogger(out_dir + '/train-log.csv'),
    ]
    # Whatever parse_learning_rate_arg returned as callbacks goes last.
    train_callbacks.extend(sched)

    train_trial = tb.Trial(model,
                           optimiser,
                           torch.nn.CrossEntropyLoss(),
                           metrics=['loss', 'acc', 'lr'],
                           callbacks=train_callbacks)
    train_trial = train_trial.to(args.device)
    train_trial.with_generators(train_generator=trainloader,
                                val_generator=valloader)
    train_trial.run(epochs=args.epochs, verbose=2)

    # Separate trial without an optimiser for the test pass.
    test_callbacks = [CSVLogger(out_dir + '/test-log.csv')]
    test_trial = tb.Trial(model,
                          criterion=torch.nn.CrossEntropyLoss(),
                          metrics=['loss', 'acc'],
                          callbacks=test_callbacks)
    test_trial = test_trial.to(args.device)
    test_trial.with_generators(test_generator=testloader)
    test_trial.predict(verbose=2)
    def test_interval_on_batch(self, mock_save_check):
        """Check ``on_batch=True`` saves once per four ``on_step_training`` calls.

        Thirteen training steps are made and the save count is probed after
        the 4th, 7th and 8th step. A trailing ``on_checkpoint`` call follows,
        after which the total is expected to be 3.
        ``mock_save_check`` is presumably a patched save routine injected by a
        decorator that is not visible here.
        """
        state = {}
        check = Interval('test_file', period=4, on_batch=True)

        for i in range(13):
            check.on_step_training(state)
            # Dedicated equality assertions are used so that a failure shows
            # the actual call count instead of a bare True/False mismatch.
            if i == 3:
                self.assertEqual(mock_save_check.call_count, 1)
            elif i == 6:
                # One step short of the second period: no second save yet.
                self.assertNotEqual(mock_save_check.call_count, 2)
            elif i == 7:
                self.assertEqual(mock_save_check.call_count, 2)
        check.on_checkpoint(state)
        self.assertEqual(mock_save_check.call_count, 3)