def test_mean_metric_instance(self):
    """mean() applied to a Metric *instance* should produce a MetricTree whose
    single child is a ToDict wrapping a Mean, preserving the metric's name on
    both the child metric and the tree root.
    """
    metric = metrics.Metric('test')
    out = metrics.mean(metric)
    # assertIsInstance / assertEqual give informative failure messages,
    # unlike assertTrue(isinstance(...)) which only reports "False is not true".
    self.assertIsInstance(out, metrics.MetricTree)
    self.assertIsInstance(out.children[0], metrics.ToDict)
    self.assertIsInstance(out.children[0].metric, metrics.Mean)
    self.assertEqual(out.children[0].metric.name, 'test')
    self.assertEqual(out.root.name, 'test')
def test_mean_metric(self):
    """mean() applied to the Metric *class* should return a decorated class;
    instantiating it with a name and calling build() yields a MetricTree with
    a ToDict-wrapped Mean child carrying that name.
    """
    metric = metrics.Metric
    out = metrics.mean(metric)('test').build()
    # assertIsInstance / assertEqual give informative failure messages,
    # unlike assertTrue(isinstance(...)) which only reports "False is not true".
    self.assertIsInstance(out, metrics.MetricTree)
    self.assertIsInstance(out.children[0], metrics.ToDict)
    self.assertIsInstance(out.children[0].metric, metrics.Mean)
    self.assertEqual(out.children[0].metric.name, 'test')
    self.assertEqual(out.root.name, 'test')
def test_mean_metric_class(self):
    """mean(**kwargs) used as a parameterised decorator should forward its
    keyword arguments (here dim=10) to the underlying Mean via _kwargs, while
    still building the usual MetricTree / ToDict / Mean structure.
    """
    metric = metrics.Metric
    out = metrics.mean(dim=10)(metric)('test')
    # assertIsInstance / assertEqual give informative failure messages,
    # unlike assertTrue(isinstance(...)) which only reports "False is not true".
    self.assertIsInstance(out, metrics.MetricTree)
    self.assertIsInstance(out.children[0], metrics.ToDict)
    self.assertIsInstance(out.children[0].metric, metrics.Mean)
    self.assertEqual(out.children[0].metric._kwargs['dim'], 10)
    self.assertEqual(out.children[0].metric.name, 'test')
    self.assertEqual(out.root.name, 'test')
def test_mean_metric_factory(self, build_mock):
    """mean() applied to a MetricFactory should defer to the factory's build()
    (mocked here to return a named Metric) and wrap the built metric in the
    standard MetricTree / ToDict / Mean structure.

    Args:
        build_mock: patched MetricFactory.build (injected by the mock
            decorator on this test, outside this view).
    """
    metric = metrics.MetricFactory
    build_mock.return_value = metrics.Metric('test')
    out = metrics.mean(metric)().build()
    # assertIsInstance / assertEqual give informative failure messages,
    # unlike assertTrue(isinstance(...)) which only reports "False is not true".
    self.assertIsInstance(out, metrics.MetricTree)
    self.assertIsInstance(out.children[0], metrics.ToDict)
    self.assertIsInstance(out.children[0].metric, metrics.Mean)
    self.assertEqual(out.children[0].metric.name, 'test')
    self.assertEqual(out.root.name, 'test')
# Select the data-augmentation callback requested on the command line.
# 'base' (or any unrecognised mode) leaves aug empty.
aug = []
mode = args.mode
if mode == 'mix':
    aug = [callbacks.Mixup()]
if mode == 'cutmix':
    aug = [callbacks.CutMix(1, classes=10)]
if mode == 'fmix':
    aug = [FMix(alpha=1, decay_power=3)]

# Build the VAE and the torchbearer Trial that trains it with Adam and the
# nll criterion, logging MSE on the sampled output plus mean NLL/KL and loss.
model = VAE(64, var=args.var)
trial = Trial(model, optim.Adam(model.parameters(), lr=5e-2), nll, metrics=[
    metrics.MeanSquaredError(pred_key=SAMPLE),
    metrics.mean(NLL),
    metrics.mean(KL),
    'loss'
], callbacks=[
    sample,
    kld(distributions.Normal(0, 1)),  # KL term against a standard normal prior
    init.XavierNormal(targets=['Conv']),
    # Checkpoint the most recent model, named by mode and run index.
    callbacks.MostRecent(args.dir + '/' + mode + '_' + str(args.i) + '.pt'),
    callbacks.MultiStepLR([40, 80]),  # LR decay at epochs 40 and 80
    callbacks.TensorBoard(write_graph=False, comment=mode + '_' + str(args.i), log_dir='vae_logs'),
    *aug  # mode-specific augmentation callback, if any
])
if mode in ['base', 'mix', 'cutmix']:  # NOTE(review): branch body continues beyond this chunk
# Generator and discriminator each get a torchbearer base closure wired to
# their own state keys (model, criterion, loss, optimiser).
closure_gen = base_closure(tb.X, tb.MODEL, tb.Y_PRED, tb.Y_TRUE, tb.CRITERION, tb.LOSS, GEN_OPT)
closure_disc = base_closure(tb.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT, tb.LOSS, DISC_OPT)


def closure(state):
    """Combined training closure: run the generator pass then the
    discriminator pass, stepping each optimiser right after its own pass."""
    closure_gen(state)
    state[GEN_OPT].step()
    closure_disc(state)
    state[DISC_OPT].step()


from torchbearer.metrics import mean, running_mean

# Report the raw loss plus running means of the per-network losses.
metrics = ['loss', mean(running_mean(D_LOSS)), mean(running_mean(G_LOSS))]

# Optimiser is None here because stepping is handled manually in closure().
trial = tb.Trial(generator, None, criterion=gen_crit, metrics=metrics, callbacks=[saver_callback])
trial.with_train_generator(dataloader, steps=200000)
trial.to(device)

# Extra state entries so the closures can reach the discriminator, its
# optimiser, and its criterion through the trial state.
new_keys = {
    DISC_MODEL: discriminator.to(device),
    DISC_OPT: optimizer_D,
    GEN_OPT: optimizer_G,
    DISC_CRIT: disc_crit
}