def test_save(self, mock_save_check):
    """Check that ``MostRecent`` saves once per ``on_checkpoint`` call.

    ``mock_save_check`` is presumably injected by a ``mock.patch`` decorator
    on the underlying save routine (the decorator is not visible here --
    confirm). Two ``on_checkpoint`` calls must produce exactly two saves.

    NOTE(review): another method named ``test_save`` is defined next to this
    one. If both live in the same TestCase class, the later definition
    silently replaces this one and it never runs -- confirm.
    """
    state = {}
    check = MostRecent('test_file.pt')

    check.on_checkpoint(state)
    check.on_checkpoint(state)

    # assertEqual reports the actual count on failure, whereas
    # assertTrue(count == 2) would only say "False is not true".
    self.assertEqual(mock_save_check.call_count, 2)
def test_save(self, mock_save_check):
    """Check that ``MostRecent`` saves once per ``on_end_epoch`` call.

    ``mock_save_check`` is presumably injected by a ``mock.patch`` decorator
    on the underlying save routine (the decorator is not visible here --
    confirm). Two ``on_end_epoch`` calls must produce exactly two saves.

    NOTE(review): another method named ``test_save`` is defined next to this
    one. If both live in the same TestCase class, one definition silently
    replaces the other and only one of them runs -- confirm.
    """
    state = {}
    check = MostRecent('test_file.pt')

    check.on_end_epoch(state)
    check.on_end_epoch(state)

    # assertEqual reports the actual count on failure, whereas
    # assertTrue(count == 2) would only say "False is not true".
    self.assertEqual(mock_save_check.call_count, 2)
def train(args, model, model_loss, trainloader, valloader, epochs, name='model'):
    """Fit ``model`` with a torchbearer Trial and return the finished trial.

    Periodic checkpoints, a final checkpoint and a CSV training log are
    written under ``args.output``. Every filename starts with ``name``.

    Args:
        args: parsed command-line namespace. ``learning_rate``, ``output``
            and ``device`` are read here, and the whole namespace is also
            handed to ``build_optimiser``.
        model: the network to optimise.
        model_loss: training criterion. The ``'acc'`` metric is tracked only
            when this is an ``nn.CrossEntropyLoss``.
        trainloader: generator of training batches.
        valloader: generator of validation batches.
        epochs: number of epochs passed to ``Trial.run``.
        name: filename prefix for checkpoints and the training log.

    Returns:
        The ``tb.Trial`` object after ``run`` has completed.
    """
    start_lr, lr_callbacks = parse_learning_rate_arg(args.learning_rate)
    # Plain concatenation (not an f-string) so a non-str ``name`` still raises.
    prefix = str(args.output) + '/' + name
    optimiser = build_optimiser(args, model, start_lr)

    callbacks = [
        # '{epoch:02d}' is a literal placeholder for the checkpointer to fill in.
        Interval(filepath=prefix + '.{epoch:02d}.pt', period=10),
        MostRecent(filepath=prefix + '_final.pt'),
        CSVLogger(prefix + '-train-log.csv'),
    ]
    # Learning-rate schedule callbacks come after the checkpoint/log callbacks.
    callbacks.extend(lr_callbacks)

    # Accuracy only makes sense for the classification criterion.
    extra = ['acc'] if isinstance(model_loss, nn.CrossEntropyLoss) else []
    tracked = ['loss', 'lr'] + extra

    trial = tb.Trial(model, optimiser, model_loss, metrics=tracked, callbacks=callbacks)
    trial = trial.to(args.device)
    trial.with_generators(train_generator=trainloader, val_generator=valloader)
    trial.run(epochs=epochs, verbose=2)
    return trial
def main():
    """Command-line entry point: train a classifier, then score it on the test split.

    Steps:
      1. Parse arguments in two passes (shared options first, then the rest).
      2. Build the train/val/test loaders and create ``args.output``
         (including parents), then call ``save_args`` on it.
      3. Train ``get_model(args.model)()`` with cross-entropy loss, using Adam
         or SGD depending on ``args.optimiser``. Attach an ``Interval``
         checkpointer (period=10), a ``MostRecent`` checkpointer, a CSV
         training log and the learning-rate schedule callbacks.
      4. Evaluate the trained model on the test loader with loss/accuracy
         metrics.
    """
    # First pass: a lenient parser reads only the shared options. Unknown
    # arguments are ignored, and without add_help a '-h' does not exit early.
    # The result is handed to add_sub_args, presumably so the remaining
    # options can depend on the shared ones.
    fake_parser = FakeArgumentParser(add_help=False, allow_abbrev=False)
    add_shared_args(fake_parser)
    fake_args, _ = fake_parser.parse_known_args()

    # Second pass: the real parser with shared + dependent options.
    parser = argparse.ArgumentParser()
    add_shared_args(parser)
    add_sub_args(fake_args, parser)
    args = parser.parse_args()

    trainloader, valloader, testloader = build_dataloaders(args)

    args.output.mkdir(exist_ok=True, parents=True)
    path = str(args.output)
    save_args(args.output)

    model = get_model(args.model)()

    init_lr, sched = parse_learning_rate_arg(args.learning_rate)
    if args.optimiser == 'Adam':
        opt = optim.Adam(model.parameters(), lr=init_lr, weight_decay=args.weight_decay)
    else:
        opt = optim.SGD(model.parameters(), lr=init_lr, weight_decay=args.weight_decay,
                        momentum=args.momentum)

    callbacks = [
        # '{epoch:02d}' is a literal placeholder for the checkpointer to fill in.
        Interval(filepath=path + '/model.{epoch:02d}.pt', period=10),
        MostRecent(filepath=path + '/model_final.pt'),
        CSVLogger(path + '/train-log.csv'),
        *sched,
    ]

    trial = tb.Trial(model, opt, torch.nn.CrossEntropyLoss(),
                     metrics=['loss', 'acc', 'lr'], callbacks=callbacks).to(args.device)
    trial.with_generators(train_generator=trainloader, val_generator=valloader)
    trial.run(epochs=args.epochs, verbose=2)

    # Test phase: a fresh Trial around the trained model (no optimiser needed).
    trial = tb.Trial(model, criterion=torch.nn.CrossEntropyLoss(), metrics=['loss', 'acc'],
                     callbacks=[CSVLogger(path + '/test-log.csv')]).to(args.device)
    trial.with_generators(test_generator=testloader)
    # Use evaluate, not predict. predict samples batches without targets and
    # returns the raw model outputs (previously discarded), so the criterion
    # and the 'loss'/'acc' metrics configured above were never computed.
    # NOTE(review): confirm that CSVLogger writes a row during an evaluation
    # pass; if it only writes at epoch end, test-log.csv may stay empty.
    trial.evaluate(data_key=tb.TEST_DATA, verbose=2)