def build_trial(args):
    """Assemble the torchbearer Trial used to train the autoencoder.

    Args:
        args: Parsed command-line options (presumably an argparse Namespace).
            Attributes read directly here: ``learning_rate``, ``weight_decay``,
            ``dataset``, ``output``, ``snapshot_interval``,
            ``num_reconstructions`` and ``variational``. The whole object is
            also forwarded to ``build_model``, ``build_loss`` and
            ``model.get_callbacks``.

    Returns:
        A ``(trial, model)`` tuple. The trial optimises with Adam, uses the
        criterion from ``build_loss``, tracks 'loss', 'mse' and 'lr', and
        fetches batches through ``autoenc_loader``.
    """
    model = build_model(args)
    loss = build_loss(args)
    # Second element is unpacked into the callback list below, so it is an
    # iterable of callbacks (presumably LR-schedule callbacks).
    init_lr, sched = parse_learning_rate_arg(args.learning_rate)
    optim = torch.optim.Adam(model.parameters(), lr=init_lr, weight_decay=args.weight_decay)
    inv = get_dataset(args.dataset).inv_transform
    out_dir = str(args.output)

    def _val_images(key, filename):
        # Pull `key` from the state on validation batches, cache
        # num_reconstructions samples, tile them into a grid and write it
        # to out_dir + filename.
        return (imaging.FromState(key, transform=inv)
                .on_val()
                .cache(args.num_reconstructions)
                .make_grid()
                .with_handler(img_to_file(out_dir + filename)))

    # '{epoch:02d}' placeholders are presumably filled in by the callbacks.
    callbacks = [
        Interval(filepath=out_dir + '/model_{epoch:02d}.pt', period=args.snapshot_interval),
        CSVLogger(out_dir + '/log.csv'),
        _val_images(tb.Y_PRED, '/val_reconstruction_samples_{epoch:02d}.png'),
        _val_images(tb.Y_TRUE, '/val_samples.png'),
        *model.get_callbacks(args),
        *sched,
    ]

    if args.variational:
        @torchbearer.callbacks.add_to_loss
        def add_kld_loss_callback(state):
            # Treating LOGVAR as log-variance, this is the closed-form
            # KL(N(mu, exp(logvar)) || N(0, I)), summed over dim 1 and
            # averaged over the remaining (batch) dimension.
            logvar = state[LOGVAR]
            mu = state[MU]
            per_sample = 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1. - logvar, 1)
            return torch.mean(per_sample)

        callbacks += [add_kld_loss_callback]

    trial = tb.Trial(model, optimizer=optim, criterion=loss,
                     metrics=['loss', 'mse', 'lr'], callbacks=callbacks)
    trial.with_loader(autoenc_loader)
    return trial, model
def test_main(self):
    """FromState('test') returns the stored value when its key is present, else None."""
    callback = imaging.FromState('test')
    # Key present in the state: the stored value comes back unchanged.
    # assertEqual reports the actual value on failure, unlike assertTrue(x == 1).
    self.assertEqual(callback.on_batch({torchbearer.EPOCH: 0, 'test': 1}), 1)
    # Only a different key ('testing') is present: nothing is returned.
    self.assertIsNone(callback.on_batch({torchbearer.EPOCH: 1, 'testing': 1}))
def build_test_trial(args):
    """Assemble a torchbearer Trial for evaluating the autoencoder on test data.

    Args:
        args: Parsed command-line options (presumably an argparse Namespace).
            Attributes read directly here: ``device``, ``dataset``,
            ``output``, ``num_reconstructions``, ``classifier_weights`` and
            ``classifier_model``. The whole object is also forwarded to
            ``build_model``.

    Returns:
        A ``(trial, model)`` tuple. The model has been moved to
        ``args.device``. The trial has no optimizer or criterion, tracks 'mse'
        (plus a ClassificationMetric when ``args.classifier_weights`` is set)
        and fetches batches through ``autoenc_loader``.
    """
    model = build_model(args).to(args.device)
    inv = get_dataset(args.dataset).inv_transform
    out_dir = str(args.output)
    callbacks = [
        CSVLogger(out_dir + '/log.csv'),
        imaging.FromState(tb.Y_PRED, transform=inv).on_test().cache(args.num_reconstructions).make_grid().with_handler(
            img_to_file(out_dir + '/test_reconstruction_samples.png')),
        imaging.FromState(tb.Y_TRUE, transform=inv).on_test().cache(args.num_reconstructions).make_grid().with_handler(
            img_to_file(out_dir + '/test_samples.png')),
    ]

    # Further metrics that can be appended here: ChamferMetric(), ModifiedHausdorffMetric().
    metrics = ['mse']

    if args.classifier_weights:
        classifier = get_classifier_model(args.classifier_model)().to(args.device)
        # NOTE(review): torch.load unpickles the file; only point
        # --classifier_weights at trusted checkpoints.
        state = torch.load(args.classifier_weights, map_location=args.device)
        classifier.load_state_dict(state[tb.MODEL])
        # The Trial only manages `model`'s train/eval mode, not this
        # classifier's. Switch it to inference mode so dropout / batch-norm
        # layers are deterministic and BN running stats are not updated while
        # scoring reconstructions. This is idempotent if ClassificationMetric
        # also does it.
        classifier.eval()
        metrics.append(ClassificationMetric(classifier))

    trial = tb.Trial(model, metrics=metrics, callbacks=callbacks)
    trial.with_loader(autoenc_loader)
    return trial, model