def build_trial(args):
    """Assemble the training ``tb.Trial`` for the (optionally variational) autoencoder.

    Builds the model, loss, Adam optimiser and learning-rate schedulers from
    ``args``, then wires up checkpointing, CSV logging and image-grid dumps of
    validation predictions/targets under ``args.output``.

    Args:
        args: parsed experiment arguments (read: learning_rate, weight_decay,
            dataset, output, snapshot_interval, num_reconstructions, variational).

    Returns:
        tuple: ``(trial, model)`` - the configured trial and the model it trains.
    """
    model = build_model(args)
    criterion = build_loss(args)

    lr, lr_schedulers = parse_learning_rate_arg(args.learning_rate)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)

    inverse_transform = get_dataset(args.dataset).inv_transform
    out_dir = str(args.output)

    def _val_grid_writer(state_key, filename):
        # Cache the first `num_reconstructions` validation images stored under
        # `state_key`, tile them into a grid and save it below `out_dir`.
        return imaging.FromState(state_key, transform=inverse_transform).on_val().cache(
            args.num_reconstructions).make_grid().with_handler(img_to_file(out_dir + filename))

    callbacks = [
        Interval(filepath=out_dir + '/model_{epoch:02d}.pt', period=args.snapshot_interval),
        CSVLogger(out_dir + '/log.csv'),
        _val_grid_writer(tb.Y_PRED, '/val_reconstruction_samples_{epoch:02d}.png'),
        _val_grid_writer(tb.Y_TRUE, '/val_samples.png'),
    ]
    # Model-specific callbacks first, then the learning-rate schedulers.
    callbacks.extend(model.get_callbacks(args))
    callbacks.extend(lr_schedulers)

    if args.variational:
        @torchbearer.callbacks.add_to_loss
        def add_kld_loss_callback(state):
            # Closed-form KL divergence of N(mu, exp(logvar)) from N(0, 1), summed
            # over dim 1 and averaged over the batch (assumes LOGVAR holds the
            # log-variance - TODO confirm against the model).
            logvar = state[LOGVAR]
            mu = state[MU]
            return torch.mean(0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1. - logvar, 1))

        callbacks.append(add_kld_loss_callback)

    trial = tb.Trial(model, optimizer=optimizer, criterion=criterion,
                     metrics=['loss', 'mse', 'lr'], callbacks=callbacks)
    trial.with_loader(autoenc_loader)

    return trial, model
Esempio n. 2
0
    def test_csv_closed(self, mock_open):
        """The file handle returned by the patched ``open`` is closed after a full run."""
        metrics = {
            'test_metric_1': 0.1,
            'test_metric_2': 5
        }
        state = {torchbearer.EPOCH: 0, torchbearer.BATCH: 1, torchbearer.METRICS: metrics}

        logger = CSVLogger('test_file.log', write_header=False)
        # Drive the logger through one start -> step -> epoch end -> end cycle.
        for hook_name in ('on_start', 'on_step_training', 'on_end_epoch', 'on_end'):
            getattr(logger, hook_name)(state)

        self.assertTrue(mock_open.return_value.close.called)
Esempio n. 3
0
    def test_write_on_epoch(self, mock_open, mock_write):
        """With default settings, one step plus one epoch end yields exactly one write."""
        state = dict()
        state[torchbearer.EPOCH] = 0
        state[torchbearer.BATCH] = 1
        state[torchbearer.METRICS] = {'test_metric_1': 0.1, 'test_metric_2': 5}

        logger = CSVLogger('test_file.log')
        for event in ('on_start', 'on_step_training', 'on_end_epoch', 'on_end'):
            getattr(logger, event)(state)

        self.assertEqual(mock_write.call_count, 1)
Esempio n. 4
0
    def test_get_field_dict(self, mock_open):
        """``_get_field_dict`` flattens epoch, batch and every metric into one row dict."""
        metrics = {'test_metric_1': 0.1, 'test_metric_2': 5}
        state = {torchbearer.EPOCH: 0, torchbearer.BATCH: 1, torchbearer.METRICS: metrics}
        expected = dict(epoch=0, batch=1, test_metric_1=0.1, test_metric_2=5)

        logger = CSVLogger('test_file.log', batch_granularity=True)

        self.assertDictEqual(logger._get_field_dict(state), expected)
Esempio n. 5
0
    def test_batch_granularity(self, mock_open, mock_write):
        """With ``batch_granularity=True`` the logger writes on training steps as well.

        The test expects three writes in total for two ``on_step_training``
        calls and one ``on_end_epoch`` call.
        """
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log', batch_granularity=True)
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # assertEqual reports the actual count on failure (assertTrue only says "False is not true").
        self.assertEqual(mock_write.call_count, 3)
Esempio n. 6
0
    def test_get_field_dict(self, mock_open):
        """A batch-granularity logger's row dict merges the epoch/batch counters with the metrics."""
        epoch, batch = 0, 1
        metric_values = {'test_metric_1': 0.1, 'test_metric_2': 5}
        state = {
            torchbearer.EPOCH: epoch,
            torchbearer.BATCH: batch,
            torchbearer.METRICS: metric_values,
        }
        expected_row = {'epoch': epoch, 'batch': batch}
        expected_row.update(metric_values)

        logger = CSVLogger('test_file.log', batch_granularity=True)
        actual_row = logger._get_field_dict(state)

        self.assertDictEqual(actual_row, expected_row)
Esempio n. 7
0
    def test_write_header(self, mock_open):
        """With the default ``write_header`` the column names reach the opened file."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log')
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # First positional argument of the second call recorded on the patched
        # `open` mock - presumably the write of the header row; verify if the
        # logger's write sequence changes.
        header = mock_open.mock_calls[1][1][0]
        for column in ('epoch', 'test_metric_1', 'test_metric_2'):
            # assertIn shows both the missing column and the header on failure.
            self.assertIn(column, header)
Esempio n. 8
0
    def test_append(self, mock_open):
        """``append=True`` makes the logger open its file in append mode."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log', append=True)
        logger.on_start(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        import sys
        # The expected mode differs by interpreter: binary append on Python 2,
        # text append on Python 3.
        expected_mode = 'ab' if sys.version_info[0] < 3 else 'a'
        # call_args[0][1] is the mode positional argument of the last `open` call.
        self.assertEqual(mock_open.call_args[0][1], expected_mode)
Esempio n. 9
0
    def test_append(self, mock_open):
        """``append=True`` makes the logger open its file with mode ``'a+'``."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', append=True)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # call_args[0][1] is the mode positional argument of the last `open` call;
        # assertEqual shows the actual mode on failure.
        self.assertEqual(mock_open.call_args[0][1], 'a+')
Esempio n. 10
0
    def test_write_on_epoch(self, mock_open, mock_write):
        """With default settings, one step plus one epoch end yields exactly one write."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log')

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # Check call_count explicitly: Mock.assert_called_once only exists from
        # Python 3.6, and on older mocks an unknown assert_* attribute is
        # auto-created and silently passes.
        self.assertEqual(mock_write.call_count, 1)
Esempio n. 11
0
    def test_append(self, mock_open):
        """``append=True`` opens the log file with mode ``'a+'``."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', append=True)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # Mode is the second positional argument of the most recent `open` call.
        open_mode = mock_open.call_args[0][1]
        self.assertEqual(open_mode, 'a+')
Esempio n. 12
0
    def test_csv_closed(self, mock_open):
        """``on_end`` must close the handle returned by the patched ``open``."""
        state = dict([
            (torchbearer.EPOCH, 0),
            (torchbearer.BATCH, 1),
            (torchbearer.METRICS, {'test_metric_1': 0.1, 'test_metric_2': 5}),
        ])

        logger = CSVLogger('test_file.log', write_header=False)
        for callback_hook in ('on_step_training', 'on_end_epoch', 'on_end'):
            getattr(logger, callback_hook)(state)

        file_handle = mock_open.return_value
        self.assertTrue(file_handle.close.called)
Esempio n. 13
0
    def test_write_no_header(self, mock_open):
        """With ``write_header=False`` the header line is not the second recorded call's argument."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', write_header=False)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # assertNotEqual shows the offending value on failure (assertFalse does not).
        self.assertNotEqual(mock_open.mock_calls[1][1][0], 'epoch,test_metric_1,test_metric_2\r\n')
Esempio n. 14
0
    def test_write_no_header(self, mock_open):
        """Disabling ``write_header`` must keep the CSV header row out of the output."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {
                'test_metric_1': 0.1,
                'test_metric_2': 5
            }
        }

        logger = CSVLogger('test_file.log', write_header=False)

        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # First positional argument of the second call recorded on the `open` mock.
        second_call_arg = mock_open.mock_calls[1][1][0]
        self.assertNotEqual(second_call_arg,
                            'epoch,test_metric_1,test_metric_2\r\n')
Esempio n. 15
0
    def test_batch_granularity(self, mock_open, mock_write):
        """Batch granularity: two training steps plus one epoch end are expected to give three writes."""
        state = {
            torchbearer.EPOCH: 0,
            torchbearer.BATCH: 1,
            torchbearer.METRICS: {'test_metric_1': 0.1, 'test_metric_2': 5}
        }

        logger = CSVLogger('test_file.log', batch_granularity=True)

        logger.on_step_training(state)
        logger.on_step_training(state)
        logger.on_end_epoch(state)
        logger.on_end(state)

        # assertEqual reports the actual write count when the expectation fails.
        self.assertEqual(mock_write.call_count, 3)
def build_test_trial(args):
    """Assemble an evaluation ``tb.Trial`` that logs metrics and dumps test image grids.

    The model is built and moved to ``args.device``. When
    ``args.classifier_weights`` is set, a pretrained classifier is loaded and
    wrapped in a ``ClassificationMetric`` alongside MSE.

    Args:
        args: parsed experiment arguments (read: device, dataset, output,
            num_reconstructions, classifier_weights, classifier_model).

    Returns:
        tuple: ``(trial, model)`` - the configured trial (no optimiser or
        criterion) and the model it evaluates.
    """
    model = build_model(args).to(args.device)
    inv = get_dataset(args.dataset).inv_transform

    # NOTE(review): this writes to the same '/log.csv' path as the training
    # trial's CSVLogger; if both share args.output the training log may be
    # replaced - confirm this is intended.
    callbacks = [
        CSVLogger(str(args.output) + '/log.csv'),
        imaging.FromState(tb.Y_PRED, transform=inv).on_test().cache(args.num_reconstructions).make_grid().with_handler(
            img_to_file(str(args.output) + '/test_reconstruction_samples.png')),
        imaging.FromState(tb.Y_TRUE, transform=inv).on_test().cache(args.num_reconstructions).make_grid().with_handler(
            img_to_file(str(args.output) + '/test_samples.png'))
    ]

    # Additional metrics: ChamferMetric(), ModifiedHausdorffMetric()
    metrics = ['mse']

    if args.classifier_weights:
        classifier = get_classifier_model(args.classifier_model)().to(args.device)
        # torch.load unpickles the checkpoint: only load trusted files.
        state = torch.load(args.classifier_weights, map_location=args.device)
        classifier.load_state_dict(state[tb.MODEL])
        # The classifier is only used as a frozen scorer. The Trial switches its own
        # model to eval mode, but not this one, so do it explicitly. Otherwise
        # dropout / batch-norm layers would run in training mode during evaluation.
        classifier.eval()
        metrics.append(ClassificationMetric(classifier))

    trial = tb.Trial(model, metrics=metrics, callbacks=callbacks)
    trial.with_loader(autoenc_loader)

    return trial, model