Exemplo n.º 1
0
 def test_running_mean_metric_factory(self, build_mock):
     """Check that running_mean applied to a MetricFactory class builds a running-mean tree.

     ``build_mock`` is a mock injected by a patch decorator that is not
     visible here (presumably patching ``MetricFactory.build`` -- confirm).
     Its return value is set to a plain ``Metric('test')``, so the built
     tree's root can be checked by name.

     Expects the built result to be a ``MetricTree`` whose root is named
     ``'test'``. Its first child must be a ``ToDict`` wrapping a
     ``RunningMean`` named ``'running_test'``.
     """
     metric = metrics.MetricFactory
     build_mock.return_value = metrics.Metric('test')
     # Decorate the factory class, instantiate it, then build the metric.
     out = metrics.running_mean(metric)().build()
     # assertIsInstance / assertEqual report the offending value on failure,
     # whereas assertTrue(...) only reports "False is not true".
     self.assertIsInstance(out, metrics.MetricTree)
     wrapper = out.children[0]
     self.assertIsInstance(wrapper, metrics.ToDict)
     self.assertIsInstance(wrapper.metric, metrics.RunningMean)
     self.assertEqual(wrapper.metric.name, 'running_test')
     self.assertEqual(out.root.name, 'test')
Exemplo n.º 2
0
 def test_running_mean_metric_class(self):
     """Check that running_mean(batch_size, step_size) applied to a Metric class forwards its settings.

     Decorates ``metrics.Metric`` and instantiates it with the name
     ``'test'`` (no ``.build()`` call). Expects a ``MetricTree`` whose root
     is named ``'test'``. Its first child must be a ``ToDict`` wrapping a
     ``RunningMean`` named ``'running_test'`` that carries
     ``_batch_size == 40`` and ``_step_size == 20``.
     """
     metric = metrics.Metric
     out = metrics.running_mean(batch_size=40, step_size=20)(metric)('test')
     # assertIsInstance / assertEqual report the offending value on failure,
     # whereas assertTrue(...) only reports "False is not true".
     self.assertIsInstance(out, metrics.MetricTree)
     wrapper = out.children[0]
     self.assertIsInstance(wrapper, metrics.ToDict)
     self.assertIsInstance(wrapper.metric, metrics.RunningMean)
     self.assertEqual(wrapper.metric._batch_size, 40)
     self.assertEqual(wrapper.metric._step_size, 20)
     self.assertEqual(wrapper.metric.name, 'running_test')
     self.assertEqual(out.root.name, 'test')
Exemplo n.º 3
0
 def test_running_mean_metric(self):
     """Check that running_mean with explicit settings still yields the expected tree after build().

     Same setup as the class test, but calls ``.build()`` on the
     instantiated decorated metric. Expects a ``MetricTree`` whose root is
     named ``'test'``. Its first child must be a ``ToDict`` wrapping a
     ``RunningMean`` named ``'running_test'`` with ``_batch_size == 40``
     and ``_step_size == 20``.
     """
     metric = metrics.Metric
     out = metrics.running_mean(batch_size=40, step_size=20)(metric)('test').build()
     # assertIsInstance / assertEqual report the offending value on failure,
     # whereas assertTrue(...) only reports "False is not true".
     self.assertIsInstance(out, metrics.MetricTree)
     wrapper = out.children[0]
     self.assertIsInstance(wrapper, metrics.ToDict)
     self.assertIsInstance(wrapper.metric, metrics.RunningMean)
     self.assertEqual(wrapper.metric._batch_size, 40)
     self.assertEqual(wrapper.metric._step_size, 20)
     self.assertEqual(wrapper.metric.name, 'running_test')
     self.assertEqual(out.root.name, 'test')
Exemplo n.º 4
0
# Sub-closure for the generator, built from state keys.
# NOTE(review): the positional arguments are presumably (input, model,
# prediction, target, criterion, loss, optimiser) keys in the order
# torchbearer's base_closure expects -- confirm against its signature.
closure_gen = base_closure(tb.X, tb.MODEL, tb.Y_PRED, tb.Y_TRUE, tb.CRITERION,
                           tb.LOSS, GEN_OPT)
# Sub-closure for the discriminator. Its first argument is tb.Y_PRED (the key
# passed third to the generator closure above), it passes None in the third
# slot and DISC_IMGS in the fourth.
# NOTE(review): both closures pass tb.LOSS in the same slot. If that is the key
# the loss is stored under, the discriminator loss overwrites the generator
# loss in state -- confirm this is intended.
closure_disc = base_closure(tb.Y_PRED, DISC_MODEL, None, DISC_IMGS, DISC_CRIT,
                            tb.LOSS, DISC_OPT)


def closure(state):
    """Run both sub-closures on ``state``, stepping each one's optimiser right after it.

    Order matters: the generator sub-closure runs and the optimiser stored
    under GEN_OPT is stepped first. Then the discriminator sub-closure runs
    and the optimiser under DISC_OPT is stepped.
    """
    for sub_closure, opt_key in ((closure_gen, GEN_OPT), (closure_disc, DISC_OPT)):
        sub_closure(state)
        state[opt_key].step()


from torchbearer.metrics import mean, running_mean

# Metrics to report: the built-in 'loss' entry plus the mean of the running
# means of the D_LOSS and G_LOSS values.
metrics = [
    'loss',
    mean(running_mean(D_LOSS)),
    mean(running_mean(G_LOSS)),
]

# NOTE(review): the second positional argument is left as None (presumably
# the optimiser slot). Both optimisers are instead placed in `new_keys`
# below -- confirm against the tb.Trial signature.
trial = tb.Trial(
    generator,
    None,
    criterion=gen_crit,
    metrics=metrics,
    callbacks=[saver_callback],
)
trial.with_train_generator(dataloader, steps=200000)
trial.to(device)

# Extra state entries: the discriminator (moved to `device`), both optimisers
# and the discriminator criterion. Presumably merged into the trial state
# after this point -- not visible here.
new_keys = {
    DISC_MODEL: discriminator.to(device),
    DISC_OPT: optimizer_D,
    GEN_OPT: optimizer_G,
    DISC_CRIT: disc_crit,
}