def build_trial(args):
    """Assemble the torchbearer training trial for the autoencoder.

    Creates the model, the loss and an Adam optimiser from ``args``, wires up
    the snapshot / CSV / image-dump callbacks and, when ``args.variational``
    is truthy, registers an extra KL-divergence term on the loss.

    Args:
        args: Parsed options. Attributes read here: ``learning_rate``,
            ``weight_decay``, ``dataset``, ``output``, ``snapshot_interval``,
            ``num_reconstructions`` and ``variational``. The whole object is
            also forwarded to ``build_model``, ``build_loss`` and
            ``model.get_callbacks``.

    Returns:
        A ``(trial, model)`` tuple. The trial has ``autoenc_loader`` installed
        as its loader.
    """
    net = build_model(args)
    criterion = build_loss(args)

    # The learning-rate argument yields the starting rate plus an iterable of
    # callbacks (presumably schedulers) that is appended to the list below.
    start_lr, lr_callbacks = parse_learning_rate_arg(args.learning_rate)
    optimiser = torch.optim.Adam(
        net.parameters(),
        lr=start_lr,
        weight_decay=args.weight_decay,
    )

    inverse_tf = get_dataset(args.dataset).inv_transform
    out_dir = str(args.output)

    snapshot = Interval(
        filepath=out_dir + '/model_{epoch:02d}.pt',
        period=args.snapshot_interval,
    )
    csv_log = CSVLogger(out_dir + '/log.csv')

    # Validation reconstructions: one image grid file per epoch.
    recon_grid = (
        imaging.FromState(tb.Y_PRED, transform=inverse_tf)
        .on_val()
        .cache(args.num_reconstructions)
        .make_grid()
        .with_handler(img_to_file(out_dir + '/val_reconstruction_samples_{epoch:02d}.png'))
    )
    # Validation targets: fixed file name, so the same file is reused.
    target_grid = (
        imaging.FromState(tb.Y_TRUE, transform=inverse_tf)
        .on_val()
        .cache(args.num_reconstructions)
        .make_grid()
        .with_handler(img_to_file(out_dir + '/val_samples.png'))
    )

    callbacks = [snapshot, csv_log, recon_grid, target_grid]
    callbacks.extend(net.get_callbacks(args))
    callbacks.extend(lr_callbacks)

    if args.variational:
        def add_kld_loss_callback(state):
            # KL divergence of a diagonal Gaussian (mean MU, log-variance
            # LOGVAR) from a unit Gaussian: summed over dim 1, then averaged.
            logvar, mu = state[LOGVAR], state[MU]
            per_sample = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1. - logvar, 1)
            return torch.mean(per_sample)

        callbacks.append(torchbearer.callbacks.add_to_loss(add_kld_loss_callback))

    trial = tb.Trial(
        net,
        optimizer=optimiser,
        criterion=criterion,
        metrics=['loss', 'mse', 'lr'],
        callbacks=callbacks,
    )
    trial.with_loader(autoenc_loader)

    return trial, net
Example #2
0
    def test_main(self):
        """Check ``FromState.on_batch`` for a present and an absent state key.

        With the configured key in the state dict the stored value must be
        returned; with the key absent the result must be ``None``.
        """
        callback = imaging.FromState('test')

        # Key present: the value stored under 'test' is handed back.
        # assertEqual reports the actual value on failure, unlike
        # assertTrue(x == 1), which only says "False is not true".
        present = callback.on_batch({
            torchbearer.EPOCH: 0,
            'test': 1
        })
        self.assertEqual(present, 1)

        # Key absent (the state only holds 'testing'): nothing is returned.
        missing = callback.on_batch({
            torchbearer.EPOCH: 1,
            'testing': 1
        })
        self.assertIsNone(missing)
def build_test_trial(args):
    """Assemble the torchbearer evaluation trial for the autoencoder.

    Builds the model on ``args.device``, attaches a CSV logger plus two
    test-time image-grid dumps (predictions and targets), and reports the
    ``'mse'`` metric. When ``args.classifier_weights`` is truthy, a classifier
    is loaded from that file and added as a ``ClassificationMetric``.

    Args:
        args: Parsed options. Attributes read here: ``device``, ``dataset``,
            ``output``, ``num_reconstructions``, ``classifier_weights`` and
            ``classifier_model``. The whole object is also forwarded to
            ``build_model``.

    Returns:
        A ``(trial, model)`` tuple. The trial has ``autoenc_loader`` installed
        as its loader and is created without an optimiser or criterion.
    """
    net = build_model(args).to(args.device)
    inverse_tf = get_dataset(args.dataset).inv_transform
    out_dir = str(args.output)

    def test_grid(state_key, file_name):
        # Collect args.num_reconstructions test samples from the given state
        # key, tile them into one grid and write it under the output directory.
        grid = (
            imaging.FromState(state_key, transform=inverse_tf)
            .on_test()
            .cache(args.num_reconstructions)
            .make_grid()
        )
        return grid.with_handler(img_to_file(out_dir + file_name))

    callbacks = [
        CSVLogger(out_dir + '/log.csv'),
        test_grid(tb.Y_PRED, '/test_reconstruction_samples.png'),
        test_grid(tb.Y_TRUE, '/test_samples.png'),
    ]

    # Additional metrics: ChamferMetric(), ModifiedHausdorffMetric()
    metrics = ['mse']

    weights_path = args.classifier_weights
    if weights_path:
        # Instantiate the selected classifier and restore its weights from
        # the checkpoint's tb.MODEL entry.
        classifier_factory = get_classifier_model(args.classifier_model)
        classifier = classifier_factory().to(args.device)
        checkpoint = torch.load(weights_path, map_location=args.device)
        classifier.load_state_dict(checkpoint[tb.MODEL])
        metrics.append(ClassificationMetric(classifier))

    trial = tb.Trial(net, metrics=metrics, callbacks=callbacks)
    trial.with_loader(autoenc_loader)

    return trial, net