def test_multitask_interleaving_trainer(self, distribution):
    """Interleaving trainer trains sampled tasks and reports per-task metrics.

    Builds a two-task MultiTask ("foo" and "bar"), trains for 5 steps with a
    uniform task sampler, and checks that each task's result dict contains its
    loss/accuracy metrics while no aggregate "total_loss" key is emitted.
    """
    with distribution.scope():
        foo_task = test_utils.MockFooTask(params=test_utils.FooConfig(),
                                          name="foo")
        bar_task = test_utils.MockBarTask(params=test_utils.BarConfig(),
                                          name="bar")
        mt = multitask.MultiTask(tasks=[foo_task, bar_task])
        uniform_sampler = task_sampler.UniformTaskSampler(
            task_weights=mt.task_weights)
        trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
            multi_task=mt,
            multi_task_model=test_utils.MockMultiTaskModel(),
            optimizer=tf.keras.optimizers.SGD(0.1),
            task_sampler=uniform_sampler)
        results = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
        self.assertContainsSubset(["training_loss", "bar_acc"],
                                  results["bar"].keys())
        self.assertContainsSubset(["training_loss", "foo_acc"],
                                  results["foo"].keys())
        # Interleaving training keeps losses per task; no global total.
        self.assertNotIn("total_loss", results)
Example #2
 def test_multitask_evaluator(self, distribution):
   """Evaluator produces per-task validation metrics under a distribution.

   Runs a single evaluation step over two mock tasks and verifies that each
   task's results expose "validation_loss" and "acc", with the expected
   per-task loss values.
   """
   with distribution.scope():
     bar_task = MockTask(params=cfg.TaskConfig(), name="bar")
     foo_task = MockTask(params=cfg.TaskConfig(), name="foo")
     mt = multitask.MultiTask(tasks=[bar_task, foo_task])
     mt_evaluator = evaluator.MultiTaskEvaluator(task=mt, model=MockModel())
     num_steps = tf.convert_to_tensor(1, dtype=tf.int32)
     results = mt_evaluator.evaluate(num_steps)
   for task_name in ("bar", "foo"):
     self.assertContainsSubset(["validation_loss", "acc"],
                               results[task_name].keys())
   self.assertEqual(results["bar"]["validation_loss"], 0.0)
   self.assertEqual(results["foo"]["validation_loss"], 1.0)
Example #3
 def test_multitask_evaluator_numpy_metrics(self, distribution):
   """Numpy-backed metrics are summed across replicas during evaluation.

   Evaluates for 5 steps and checks that each task's "counter" metric equals
   steps times the number of replicas in sync.
   """
   with distribution.scope():
     mt = multitask.MultiTask(tasks=[
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo"),
     ])
     mt_evaluator = evaluator.MultiTaskEvaluator(task=mt, model=MockModel())
     results = mt_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
   # One increment per step per replica.
   expected_count = 5. * distribution.num_replicas_in_sync
   self.assertEqual(results["bar"]["counter"], expected_count)
   self.assertEqual(results["foo"]["counter"], expected_count)
 def test_multitask_joint_trainer(self, distribution):
   """Joint trainer reports loss and accuracy metrics for every task.

   Trains two equally-weighted mock tasks for 5 steps with the base
   (joint) trainer and checks each task's result dict contains its
   training loss and accuracy keys.
   """
   with distribution.scope():
     foo_task = test_utils.MockFooTask(params=test_utils.FooConfig(),
                                       name="foo")
     bar_task = test_utils.MockBarTask(params=test_utils.BarConfig(),
                                       name="bar")
     mt = multitask.MultiTask(tasks=[foo_task, bar_task],
                              task_weights={"foo": 1.0, "bar": 1.0})
     trainer = base_trainer.MultiTaskBaseTrainer(
         multi_task=mt,
         multi_task_model=test_utils.MockMultiTaskModel(),
         optimizer=tf.keras.optimizers.SGD(0.1))
     results = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertContainsSubset(["training_loss", "bar_acc"],
                               results["bar"].keys())
     self.assertContainsSubset(["training_loss", "foo_acc"],
                               results["foo"].keys())