def test_state_dict(self):
    """Saving state from one Interval and loading it into a fresh one restores the epoch counter."""
    source = Interval('test')
    source.epochs_since_last_save = 10
    saved = source.state_dict()

    restored = Interval('test')
    restored.load_state_dict(saved)
    self.assertEqual(restored.epochs_since_last_save, 10)
def train(args, model, model_loss, trainloader, valloader, epochs, name='model'):
    """Fit *model* on *trainloader* with torchbearer and return the finished Trial.

    Checkpoints are written under ``args.output``: a periodic snapshot every
    10 epochs, the most recent weights (``<name>_final.pt``), and a CSV
    training log.  Any schedulers parsed out of ``args.learning_rate`` are
    appended to the callback list.  Accuracy is tracked only when the loss
    is cross-entropy.
    """
    init_lr, schedulers = parse_learning_rate_arg(args.learning_rate)
    out_dir = str(args.output)
    optimiser = build_optimiser(args, model, init_lr)

    callbacks = [
        # '{epoch:02d}' is a torchbearer filename template, not an f-string field
        Interval(filepath=f'{out_dir}/{name}' + '.{epoch:02d}.pt', period=10),
        MostRecent(filepath=f'{out_dir}/{name}_final.pt'),
        CSVLogger(f'{out_dir}/{name}-train-log.csv'),
        *schedulers,
    ]

    metrics = ['loss', 'lr']
    if isinstance(model_loss, nn.CrossEntropyLoss):
        metrics.append('acc')

    trial = tb.Trial(model, optimiser, model_loss, metrics=metrics,
                     callbacks=callbacks).to(args.device)
    trial.with_generators(train_generator=trainloader, val_generator=valloader)
    trial.run(epochs=epochs, verbose=2)
    return trial
def test_interval_is_1(self, mock_save_check):
    """With period=1, every on_checkpoint call should trigger a save."""
    state = {}
    checkpointer = Interval('test_file', period=1)
    for _ in range(2):
        checkpointer.on_checkpoint(state)
    self.assertTrue(mock_save_check.call_count == 2)
def main():
    """Parse CLI arguments, train the selected model, then log test-set predictions."""
    # Two-stage parsing: a throwaway parser discovers the shared options first,
    # so model/dataset-specific options can be registered on the real parser.
    probe = FakeArgumentParser(add_help=False, allow_abbrev=False)
    add_shared_args(probe)
    probe_args, _ = probe.parse_known_args()

    parser = argparse.ArgumentParser()
    add_shared_args(parser)
    add_sub_args(probe_args, parser)
    args = parser.parse_args()

    trainloader, valloader, testloader = build_dataloaders(args)

    args.output.mkdir(exist_ok=True, parents=True)
    out_dir = str(args.output)
    save_args(args.output)

    model = get_model(args.model)()
    init_lr, schedulers = parse_learning_rate_arg(args.learning_rate)

    if args.optimiser == 'Adam':
        opt = optim.Adam(model.parameters(), lr=init_lr,
                         weight_decay=args.weight_decay)
    else:
        opt = optim.SGD(model.parameters(), lr=init_lr,
                        weight_decay=args.weight_decay, momentum=args.momentum)

    callbacks = [
        # '{epoch:02d}' is a torchbearer filename template, left unformatted here
        Interval(filepath=out_dir + '/model.{epoch:02d}.pt', period=10),
        MostRecent(filepath=f'{out_dir}/model_final.pt'),
        CSVLogger(f'{out_dir}/train-log.csv'),
        *schedulers,
    ]

    trial = tb.Trial(model, opt, torch.nn.CrossEntropyLoss(),
                     metrics=['loss', 'acc', 'lr'],
                     callbacks=callbacks).to(args.device)
    trial.with_generators(train_generator=trainloader, val_generator=valloader)
    trial.run(epochs=args.epochs, verbose=2)

    # Fresh trial for the held-out set so only the test logger is attached.
    trial = tb.Trial(model, criterion=torch.nn.CrossEntropyLoss(),
                     metrics=['loss', 'acc'],
                     callbacks=[CSVLogger(f'{out_dir}/test-log.csv')]).to(args.device)
    trial.with_generators(test_generator=testloader)
    trial.predict(verbose=2)
def test_interval_on_batch(self, mock_save_check):
    """With period=4 and on_batch=True, saves fire on training steps 4, 8 and 12.

    The trailing on_checkpoint call must NOT add a save in batch mode, so the
    final count stays at 3.
    """
    state = {}
    checkpointer = Interval('test_file', period=4, on_batch=True)

    # step index -> (assertion method, expected call count at that step)
    checkpoints = {
        3: (self.assertTrue, 1),
        6: (self.assertFalse, 2),
        7: (self.assertTrue, 2),
    }
    for step in range(13):
        checkpointer.on_step_training(state)
        if step in checkpoints:
            assertion, expected = checkpoints[step]
            assertion(mock_save_check.call_count == expected)

    checkpointer.on_checkpoint(state)
    self.assertTrue(mock_save_check.call_count == 3)
def test_interval_is_more_than_1(self, mock_save_check):
    """With period=4 in epoch mode, saves fire on checkpoint calls 4, 8 and 12."""
    state = {}
    checkpointer = Interval('test_file', period=4)

    # epoch index -> (assertion method, expected call count at that epoch)
    checkpoints = {
        3: (self.assertTrue, 1),
        6: (self.assertFalse, 2),
        7: (self.assertTrue, 2),
    }
    for epoch in range(13):
        checkpointer.on_checkpoint(state)
        if epoch in checkpoints:
            assertion, expected = checkpoints[epoch]
            assertion(mock_save_check.call_count == expected)

    self.assertTrue(mock_save_check.call_count == 3)