def build_trial(args):
    """Assemble the training trial and the model it wraps.

    Builds the model, the loss and an Adam optimiser, then wires up the
    callbacks: periodic model snapshots, a CSV log, two validation image
    grids (predictions and targets), whatever ``model.get_callbacks(args)``
    returns, and the callbacks produced by ``parse_learning_rate_arg``.
    When ``args.variational`` is truthy, an extra term built by
    ``add_kld_loss_callback`` is added to the loss.

    Args:
        args: Namespace-like object. This function reads ``learning_rate``,
            ``weight_decay``, ``dataset``, ``output``, ``snapshot_interval``,
            ``num_reconstructions`` and ``variational`` from it and passes
            the whole object to ``build_model``, ``build_loss`` and
            ``model.get_callbacks``.

    Returns:
        tuple: ``(trial, model)``.
    """
    model = build_model(args)
    criterion = build_loss(args)
    initial_lr, lr_callbacks = parse_learning_rate_arg(args.learning_rate)
    optimiser = torch.optim.Adam(
        model.parameters(),
        lr=initial_lr,
        weight_decay=args.weight_decay,
    )
    inverse_transform = get_dataset(args.dataset).inv_transform

    # Every artefact path is the output directory plus a fixed suffix.
    out_dir = str(args.output)

    snapshot_cb = Interval(
        filepath=out_dir + '/model_{epoch:02d}.pt',
        period=args.snapshot_interval,
    )
    csv_cb = CSVLogger(out_dir + '/log.csv')
    reconstruction_cb = (
        imaging.FromState(tb.Y_PRED, transform=inverse_transform)
        .on_val()
        .cache(args.num_reconstructions)
        .make_grid()
        .with_handler(
            img_to_file(out_dir + '/val_reconstruction_samples_{epoch:02d}.png')
        )
    )
    target_cb = (
        imaging.FromState(tb.Y_TRUE, transform=inverse_transform)
        .on_val()
        .cache(args.num_reconstructions)
        .make_grid()
        .with_handler(img_to_file(out_dir + '/val_samples.png'))
    )

    callbacks = [snapshot_cb, csv_cb, reconstruction_cb, target_cb]
    callbacks.extend(model.get_callbacks(args))
    callbacks.extend(lr_callbacks)

    if args.variational:
        @torchbearer.callbacks.add_to_loss
        def add_kld_loss_callback(state):
            logvar = state[LOGVAR]
            mu = state[MU]
            # Sum over dimension 1, then average the per-row values.
            per_row = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1. - logvar, 1)
            return torch.mean(per_row)

        callbacks.append(add_kld_loss_callback)

    trial = tb.Trial(
        model,
        optimizer=optimiser,
        criterion=criterion,
        metrics=['loss', 'mse', 'lr'],
        callbacks=callbacks,
    )
    trial.with_loader(autoenc_loader)
    return trial, model
def test_csv_closed(self, mock_open):
    """Run a full logger lifecycle and check the opened file was closed.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view -- TODO confirm.
    """
    metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: metrics,
    }

    logger = CSVLogger('test_file.log', write_header=False)
    # Fire the hooks in lifecycle order.
    for hook_name in ('on_start', 'on_step_training', 'on_end_epoch', 'on_end'):
        getattr(logger, hook_name)(state)

    file_handle = mock_open.return_value
    self.assertTrue(file_handle.close.called)
def test_write_on_epoch(self, mock_open, mock_write):
    """Expect exactly one call on ``mock_write`` for one step and one epoch.

    The logger is built with default granularity and driven through
    ``on_start``, one training step, one epoch end and ``on_end``.

    ``mock_open`` / ``mock_write`` are presumably patched objects injected
    by decorators outside this view -- TODO confirm what each one patches.
    """
    metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: metrics,
    }

    logger = CSVLogger('test_file.log')
    for hook_name in ('on_start', 'on_step_training', 'on_end_epoch', 'on_end'):
        getattr(logger, hook_name)(state)

    write_calls = mock_write.call_count
    self.assertEqual(write_calls, 1)
def test_get_field_dict(self, mock_open):
    """Check ``_get_field_dict`` output for a batch-granularity logger.

    The expected dict holds ``'epoch'`` and ``'batch'`` entries followed by
    the metric entries taken from the state.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view -- TODO confirm.
    """
    metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: metrics,
    }
    expected = {
        'epoch': 0,
        'batch': 1,
        'test_metric_1': 0.1,
        'test_metric_2': 5,
    }

    logger = CSVLogger('test_file.log', batch_granularity=True)
    actual = logger._get_field_dict(state)

    self.assertDictEqual(actual, expected)
def test_batch_granularity(self, mock_open, mock_write):
    """Check the number of writes when logging at batch granularity.

    Drives the logger through ``on_start``, two training steps, one epoch
    end and ``on_end``, then expects exactly three calls on ``mock_write``.

    ``mock_open`` / ``mock_write`` are presumably patched objects injected
    by decorators outside this view -- TODO confirm what each one patches.
    """
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {
            'test_metric_1': 0.1,
            'test_metric_2': 5,
        },
    }

    logger = CSVLogger('test_file.log', batch_granularity=True)
    logger.on_start(state)
    logger.on_step_training(state)
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    # assertEqual reports the actual call count on failure, which
    # assertTrue(count == 3) would hide behind "False is not true".
    self.assertEqual(mock_write.call_count, 3)
def test_write_header(self, mock_open):
    """Check that the header text names the epoch and both metrics.

    The header is read from the first positional argument of the second
    call recorded on ``mock_open``.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view; that its second recorded call is the header write is
    an assumption of this test -- TODO confirm.
    """
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {
            'test_metric_1': 0.1,
            'test_metric_2': 5,
        },
    }

    logger = CSVLogger('test_file.log')
    logger.on_start(state)
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    # mock_calls entries are (name, args, kwargs); take args[0] of call #2.
    header = mock_open.mock_calls[1][1][0]

    # assertIn prints both the needle and the header on failure.
    self.assertIn('epoch', header)
    self.assertIn('test_metric_1', header)
    self.assertIn('test_metric_2', header)
def test_append(self, mock_open):
    """Check the mode passed to ``open`` when ``append=True``.

    The expected mode is ``'ab'`` on Python 2 and ``'a'`` on Python 3.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view -- TODO confirm.
    """
    import sys

    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {
            'test_metric_1': 0.1,
            'test_metric_2': 5,
        },
    }

    logger = CSVLogger('test_file.log', append=True)
    logger.on_start(state)
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    expected_mode = 'ab' if sys.version_info[0] < 3 else 'a'

    # Second positional argument of the most recent open(...) call.
    actual_mode = mock_open.call_args[0][1]

    # assertEqual shows the actual mode on failure; one assertion covers
    # both interpreter branches.
    self.assertEqual(actual_mode, expected_mode)
def test_append(self, mock_open):
    """Check that ``append=True`` makes the logger open its file as ``'a+'``.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view -- TODO confirm.
    """
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5},
    }

    logger = CSVLogger('test_file.log', append=True)
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    # Second positional argument of the most recent open(...) call.
    actual_mode = mock_open.call_args[0][1]

    # assertEqual shows the actual mode on failure, unlike assertTrue(a == b).
    self.assertEqual(actual_mode, 'a+')
def test_write_on_epoch(self, mock_open, mock_write):
    """Expect exactly one call on ``mock_write`` for one step and one epoch.

    ``mock_open`` / ``mock_write`` are presumably patched objects injected
    by decorators outside this view -- TODO confirm what each one patches.
    """
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5},
    }

    logger = CSVLogger('test_file.log')
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    # Compare call_count explicitly: Mock.assert_called_once() is missing
    # from older mock releases, where the attribute access can either raise
    # or silently succeed as an auto-created child mock.
    self.assertEqual(mock_write.call_count, 1)
def test_csv_closed(self, mock_open):
    """Drive the logger through a step, epoch end and end; expect a close.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view -- TODO confirm.
    """
    metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: metrics,
    }

    logger = CSVLogger('test_file.log', write_header=False)
    for hook_name in ('on_step_training', 'on_end_epoch', 'on_end'):
        getattr(logger, hook_name)(state)

    file_handle = mock_open.return_value
    self.assertTrue(file_handle.close.called)
def test_write_no_header(self, mock_open):
    """Check that no header line is written when ``write_header=False``.

    Compares the first positional argument of the second call recorded on
    ``mock_open`` against the literal header text and expects them to differ.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view -- TODO confirm.
    """
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5},
    }

    logger = CSVLogger('test_file.log', write_header=False)
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    # mock_calls entries are (name, args, kwargs); take args[0] of call #2.
    first_written = mock_open.mock_calls[1][1][0]

    # assertNotEqual shows the offending value on failure, unlike
    # assertFalse(a == b).
    self.assertNotEqual(first_written, 'epoch,test_metric_1,test_metric_2\r\n')
def test_write_no_header(self, mock_open):
    """Check that no header line is written when ``write_header=False``.

    Compares the first positional argument of the second call recorded on
    ``mock_open`` against the literal header text and expects them to differ.

    ``mock_open`` is presumably a patched ``open`` injected by a decorator
    outside this view -- TODO confirm.
    """
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {
            'test_metric_1': 0.1,
            'test_metric_2': 5,
        },
    }

    logger = CSVLogger('test_file.log', write_header=False)
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    # mock_calls entries are (name, args, kwargs); take args[0] of call #2.
    first_written = mock_open.mock_calls[1][1][0]

    # assertNotEqual shows the offending value on failure, unlike
    # assertFalse(a == b).
    self.assertNotEqual(first_written, 'epoch,test_metric_1,test_metric_2\r\n')
def test_batch_granularity(self, mock_open, mock_write):
    """Check the number of writes when logging at batch granularity.

    Drives the logger through two training steps, one epoch end and
    ``on_end``, then expects exactly three calls on ``mock_write``.

    ``mock_open`` / ``mock_write`` are presumably patched objects injected
    by decorators outside this view -- TODO confirm what each one patches.
    """
    state = {
        torchbearer.EPOCH: 0,
        torchbearer.BATCH: 1,
        torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5},
    }

    logger = CSVLogger('test_file.log', batch_granularity=True)
    logger.on_step_training(state)
    logger.on_step_training(state)
    logger.on_end_epoch(state)
    logger.on_end(state)

    # assertEqual reports the actual call count on failure, which
    # assertTrue(count == 3) would hide behind "False is not true".
    self.assertEqual(mock_write.call_count, 3)
def build_test_trial(args):
    """Assemble the evaluation trial and the model it wraps.

    Builds the model on ``args.device`` and attaches a CSV log plus two
    test-time image grids (predictions and targets). ``'mse'`` is always
    tracked; when ``args.classifier_weights`` is truthy, a classifier is
    loaded from that path and a ``ClassificationMetric`` wrapping it is
    tracked as well.

    Args:
        args: Namespace-like object. This function reads ``device``,
            ``dataset``, ``output``, ``num_reconstructions``,
            ``classifier_weights`` and ``classifier_model`` from it and
            passes the whole object to ``build_model``.

    Returns:
        tuple: ``(trial, model)``.
    """
    model = build_model(args).to(args.device)
    inverse_transform = get_dataset(args.dataset).inv_transform

    # Every artefact path is the output directory plus a fixed suffix.
    out_dir = str(args.output)

    csv_cb = CSVLogger(out_dir + '/log.csv')
    reconstruction_cb = (
        imaging.FromState(tb.Y_PRED, transform=inverse_transform)
        .on_test()
        .cache(args.num_reconstructions)
        .make_grid()
        .with_handler(img_to_file(out_dir + '/test_reconstruction_samples.png'))
    )
    target_cb = (
        imaging.FromState(tb.Y_TRUE, transform=inverse_transform)
        .on_test()
        .cache(args.num_reconstructions)
        .make_grid()
        .with_handler(img_to_file(out_dir + '/test_samples.png'))
    )
    callbacks = [csv_cb, reconstruction_cb, target_cb]

    # Additional metrics: ChamferMetric(), ModifiedHausdorffMetric()
    metrics = ['mse']

    if args.classifier_weights:
        classifier_cls = get_classifier_model(args.classifier_model)
        classifier = classifier_cls().to(args.device)
        checkpoint = torch.load(args.classifier_weights, map_location=args.device)
        # The checkpoint stores the model weights under the tb.MODEL key.
        classifier.load_state_dict(checkpoint[tb.MODEL])
        metrics.append(ClassificationMetric(classifier))

    trial = tb.Trial(model, metrics=metrics, callbacks=callbacks)
    trial.with_loader(autoenc_loader)
    return trial, model