Exemplo n.º 1
0
 def test_mean_metric_instance(self):
     """Check ``metrics.mean`` applied directly to a ``Metric`` instance.

     The result must be a ``MetricTree`` whose first child is a ``ToDict``
     wrapping a ``Mean`` named 'test', and whose root is also named 'test'.
     """
     metric = metrics.Metric('test')
     out = metrics.mean(metric)
     # Dedicated assertions report the offending value/type on failure,
     # unlike assertTrue(...) which only reports "False is not true".
     self.assertIsInstance(out, metrics.MetricTree)
     first_child = out.children[0]
     self.assertIsInstance(first_child, metrics.ToDict)
     self.assertIsInstance(first_child.metric, metrics.Mean)
     self.assertEqual(first_child.metric.name, 'test')
     self.assertEqual(out.root.name, 'test')
Exemplo n.º 2
0
 def test_mean_metric(self):
     """Check ``metrics.mean`` applied to the ``Metric`` class.

     The decorated class is instantiated with 'test' and ``build()`` is
     called on the result; the built object must be a ``MetricTree`` whose
     first child is a ``ToDict`` wrapping a ``Mean`` named 'test', with a
     root of the same name.
     """
     metric = metrics.Metric
     out = metrics.mean(metric)('test').build()
     # Dedicated assertions report the offending value/type on failure,
     # unlike assertTrue(...) which only reports "False is not true".
     self.assertIsInstance(out, metrics.MetricTree)
     first_child = out.children[0]
     self.assertIsInstance(first_child, metrics.ToDict)
     self.assertIsInstance(first_child.metric, metrics.Mean)
     self.assertEqual(first_child.metric.name, 'test')
     self.assertEqual(out.root.name, 'test')
Exemplo n.º 3
0
 def test_mean_metric_class(self):
     """Check ``metrics.mean(dim=10)`` used as a parameterised decorator.

     The decorator is created with ``dim=10``, applied to the ``Metric``
     class, and the result is called with 'test'. The output must be a
     ``MetricTree`` whose first child is a ``ToDict`` wrapping a ``Mean``
     that stored ``dim=10`` in its ``_kwargs`` and is named 'test'; the
     root must be named 'test' as well.
     """
     metric = metrics.Metric
     out = metrics.mean(dim=10)(metric)('test')
     # Dedicated assertions report the offending value/type on failure,
     # unlike assertTrue(...) which only reports "False is not true".
     self.assertIsInstance(out, metrics.MetricTree)
     first_child = out.children[0]
     self.assertIsInstance(first_child, metrics.ToDict)
     self.assertIsInstance(first_child.metric, metrics.Mean)
     self.assertEqual(first_child.metric._kwargs['dim'], 10)
     self.assertEqual(first_child.metric.name, 'test')
     self.assertEqual(out.root.name, 'test')
Exemplo n.º 4
0
 def test_mean_metric_factory(self, build_mock):
     """Check ``metrics.mean`` applied to ``MetricFactory``.

     Args:
         build_mock: Mock whose return value is set to ``Metric('test')``.
             NOTE(review): presumably injected by a patch decorator on
             ``MetricFactory.build`` that is outside this view - confirm.

     The decorated factory is instantiated without arguments and built;
     the result must be a ``MetricTree`` whose first child is a ``ToDict``
     wrapping a ``Mean`` named 'test', with a root of the same name.
     """
     metric = metrics.MetricFactory
     build_mock.return_value = metrics.Metric('test')
     out = metrics.mean(metric)().build()
     # Dedicated assertions report the offending value/type on failure,
     # unlike assertTrue(...) which only reports "False is not true".
     self.assertIsInstance(out, metrics.MetricTree)
     first_child = out.children[0]
     self.assertIsInstance(first_child, metrics.ToDict)
     self.assertIsInstance(first_child.metric, metrics.Mean)
     self.assertEqual(first_child.metric.name, 'test')
     self.assertEqual(out.root.name, 'test')
Exemplo n.º 5
0
    # Select the optional augmentation callback from the requested mode.
    # Any mode other than 'mix', 'cutmix' or 'fmix' leaves the list empty.
    aug = []
    mode = args.mode
    if mode == 'mix':
        aug = [callbacks.Mixup()]
    if mode == 'cutmix':
        aug = [callbacks.CutMix(1, classes=10)]
    if mode == 'fmix':
        aug = [FMix(alpha=1, decay_power=3)]

    # Build the VAE and wrap it in a Trial with an Adam optimiser (lr=5e-2)
    # and `nll` as the criterion.
    model = VAE(64, var=args.var)
    trial = Trial(model,
                  optim.Adam(model.parameters(), lr=5e-2),
                  nll,
                  # MSE reads its prediction from the SAMPLE key; NLL and KL
                  # are wrapped with metrics.mean; 'loss' is requested by name.
                  metrics=[
                      metrics.MeanSquaredError(pred_key=SAMPLE),
                      metrics.mean(NLL),
                      metrics.mean(KL), 'loss'
                  ],
                  callbacks=[
                      sample,
                      kld(distributions.Normal(0, 1)),
                      init.XavierNormal(targets=['Conv']),
                      # File path is "<args.dir>/<mode>_<args.i>.pt".
                      callbacks.MostRecent(args.dir + '/' + mode + '_' +
                                           str(args.i) + '.pt'),
                      callbacks.MultiStepLR([40, 80]),
                      # The mode-specific augmentation callbacks (possibly
                      # none) are appended after the TensorBoard callback.
                      callbacks.TensorBoard(write_graph=False,
                                            comment=mode + '_' + str(args.i),
                                            log_dir='vae_logs'), *aug
                  ])

    if mode in ['base', 'mix', 'cutmix']:
Exemplo n.º 6
0
# Two closures built by base_closure from state keys: one that uses the
# standard tb.* keys together with GEN_OPT, and one that uses the DISC_* keys,
# takes tb.Y_PRED as its first argument and passes None in the third slot.
# NOTE(review): the positional order is assumed to be
# (x, model, y_pred, y_true, criterion, loss, optimizer) - confirm against
# base_closure's signature.
closure_gen = base_closure(
    tb.X, tb.MODEL, tb.Y_PRED, tb.Y_TRUE, tb.CRITERION, tb.LOSS, GEN_OPT
)
closure_disc = base_closure(
    tb.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT, tb.LOSS, DISC_OPT
)


def closure(state):
    """Run the generator closure, then the discriminator closure.

    After each sub-closure is called with ``state``, the optimizer stored in
    ``state`` under the matching key (``GEN_OPT``, then ``DISC_OPT``) is
    stepped.

    Args:
        state: Mapping passed to both sub-closures and indexed by the
            optimizer keys.
    """
    update_order = (
        (closure_gen, GEN_OPT),
        (closure_disc, DISC_OPT),
    )
    for run_closure, opt_key in update_order:
        run_closure(state)
        state[opt_key].step()


from torchbearer.metrics import mean, running_mean
# Report 'loss' plus each GAN loss key wrapped as mean(running_mean(...)).
metrics = ['loss']
metrics.append(mean(running_mean(D_LOSS)))
metrics.append(mean(running_mean(G_LOSS)))

# The second positional argument is None.
# NOTE(review): presumably the optimizer slot, left empty because the
# optimizers are stepped manually in `closure` - confirm.
trial = tb.Trial(
    generator,
    None,
    criterion=gen_crit,
    metrics=metrics,
    callbacks=[saver_callback],
)
trial.with_train_generator(dataloader, steps=200000)
trial.to(device)

# Discriminator model (moved to `device`), both optimizers and the
# discriminator criterion, each stored under its own key.
new_keys = {}
new_keys[DISC_MODEL] = discriminator.to(device)
new_keys[DISC_OPT] = optimizer_D
new_keys[GEN_OPT] = optimizer_G
new_keys[DISC_CRIT] = disc_crit