    def test_save(self, mock_save_check):
        state = {}
        check = MostRecent('test_file.pt')

        check.on_checkpoint(state)
        check.on_checkpoint(state)

        self.assertTrue(mock_save_check.call_count == 2)
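This test method is shown outside its class; mock_save_check is injected by a patch decorator that replaces the checkpointer's save routine so nothing is written to disk. A minimal harness sketch, assuming MostRecent comes from torchbearer.callbacks and that save_checkpoint is the method the callback calls internally (both are assumptions, not taken from the snippet above):

from unittest import TestCase
from unittest.mock import patch

from torchbearer.callbacks import MostRecent


class TestMostRecent(TestCase):
    # Assumed: the checkpointer writes through a save_checkpoint() method;
    # patching it on the class records calls without touching the filesystem.
    @patch.object(MostRecent, 'save_checkpoint')
    def test_save(self, mock_save_check):
        check = MostRecent('test_file.pt')
        check.on_checkpoint({})
        check.on_checkpoint({})
        # Two checkpoint events should produce two saves
        self.assertEqual(mock_save_check.call_count, 2)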
Example #2
    def test_save(self, mock_save_check):
        state = {}
        check = MostRecent('test_file.pt')

        check.on_end_epoch(state)
        check.on_end_epoch(state)

        self.assertTrue(mock_save_check.call_count == 2)
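Example #2 exercises the same check through a different callback hook: on_end_epoch rather than on_checkpoint. The harness sketch after Example #1 applies unchanged; only the two event calls differ.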
Example #3
def train(args,
          model,
          model_loss,
          trainloader,
          valloader,
          epochs,
          name='model'):
    init_lr, sched = parse_learning_rate_arg(args.learning_rate)
    path = str(args.output)

    opt = build_optimiser(args, model, init_lr)
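    # Checkpoint every 10 epochs, always keep the most recent weights, log the
    # training history to CSV, and append any LR scheduler callbacks in `sched`.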
    callbacks = [
        Interval(filepath=path + '/' + name + '.{epoch:02d}.pt', period=10),
        MostRecent(filepath=path + '/' + name + '_final.pt'),
        CSVLogger(path + '/' + name + '-train-log.csv'), *sched
    ]

    metrics = ['loss', 'lr']
    if isinstance(model_loss, nn.CrossEntropyLoss):
        metrics.append('acc')

    trial = tb.Trial(model,
                     opt,
                     model_loss,
                     metrics=metrics,
                     callbacks=callbacks).to(args.device)
    trial.with_generators(train_generator=trainloader, val_generator=valloader)
    trial.run(epochs=epochs, verbose=2)
    return trial
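parse_learning_rate_arg and build_optimiser are helpers defined elsewhere in this code base and are not shown here. Because `sched` is unpacked straight into the callbacks list, it is presumably a list of torchbearer scheduler callbacks. A minimal sketch of what such a parser could look like, assuming a hypothetical 'lr,milestone,milestone,...' string format and torchbearer's MultiStepLR wrapper; the real helper may differ:

from torchbearer.callbacks import MultiStepLR


def parse_learning_rate_arg(arg):
    # Hypothetical format: '0.1,60,120' -> start at lr 0.1 and decay 10x
    # at epochs 60 and 120; no milestones means a constant learning rate.
    parts = str(arg).split(',')
    init_lr = float(parts[0])
    milestones = [int(p) for p in parts[1:]]
    sched = [MultiStepLR(milestones=milestones)] if milestones else []
    return init_lr, sched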
Example #4
def main():
    fake_parser = FakeArgumentParser(add_help=False, allow_abbrev=False)
    add_shared_args(fake_parser)
    fake_args, _ = fake_parser.parse_known_args()

    parser = argparse.ArgumentParser()
    add_shared_args(parser)
    add_sub_args(fake_args, parser)
    args = parser.parse_args()

    trainloader, valloader, testloader = build_dataloaders(args)

    args.output.mkdir(exist_ok=True, parents=True)
    path = str(args.output)
    save_args(args.output)

    model = get_model(args.model)()

    init_lr, sched = parse_learning_rate_arg(args.learning_rate)

    if args.optimiser == 'Adam':
        opt = optim.Adam(model.parameters(),
                         lr=init_lr,
                         weight_decay=args.weight_decay)
    else:
        opt = optim.SGD(model.parameters(),
                        lr=init_lr,
                        weight_decay=args.weight_decay,
                        momentum=args.momentum)

    callbacks = [
        Interval(filepath=path + '/model.{epoch:02d}.pt', period=10),
        MostRecent(filepath=path + '/model_final.pt'),
        CSVLogger(path + '/train-log.csv'), *sched
    ]

    trial = tb.Trial(model,
                     opt,
                     torch.nn.CrossEntropyLoss(),
                     metrics=['loss', 'acc', 'lr'],
                     callbacks=callbacks).to(args.device)
    trial.with_generators(train_generator=trainloader, val_generator=valloader)
    trial.run(epochs=args.epochs, verbose=2)

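    # Wrap the trained model in a fresh Trial (no optimiser needed) with its own
    # CSV logger, then run it over the held-out test set.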
    trial = tb.Trial(model,
                     criterion=torch.nn.CrossEntropyLoss(),
                     metrics=['loss', 'acc'],
                     callbacks=[CSVLogger(path + '/test-log.csv')
                                ]).to(args.device)
    trial.with_generators(test_generator=testloader)
    trial.predict(verbose=2)
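Trial.predict runs the model over the test generator and returns its raw predictions. If the goal is instead to report the 'loss' and 'acc' metrics configured above on the test set, torchbearer's Trial.evaluate is the metric-reporting counterpart; a small follow-up sketch, assuming the standard torchbearer API in which torchbearer.TEST_DATA selects the test generator:

import torchbearer

# Compute the configured metrics over the test data rather than collecting predictions
results = trial.evaluate(verbose=2, data_key=torchbearer.TEST_DATA)
print(results)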