def test_multitask_interleaving_trainer(self, distribution):
  """Interleaving trainer reports per-task metrics and no `total_loss`.

  Builds a foo/bar multi-task setup driven by a uniform task sampler, calls
  `train()` once with a scalar int32 tensor of 5, then checks that each
  task's results contain `training_loss` plus that task's accuracy key, and
  that the top-level results carry no `total_loss` entry.
  """
  with distribution.scope():
    foo_task = test_utils.MockFooTask(
        params=test_utils.FooConfig(), name="foo")
    bar_task = test_utils.MockBarTask(
        params=test_utils.BarConfig(), name="bar")
    multi_task = multitask.MultiTask(tasks=[foo_task, bar_task])
    sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
    mock_model = test_utils.MockMultiTaskModel()
    # The sampler reuses whatever weights the MultiTask was built with.
    uniform_sampler = task_sampler.UniformTaskSampler(
        task_weights=multi_task.task_weights)
    trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
        multi_task=multi_task,
        multi_task_model=mock_model,
        optimizer=sgd,
        task_sampler=uniform_sampler)
    results = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    per_task_expectations = (
        ("bar", ["training_loss", "bar_acc"]),
        ("foo", ["training_loss", "foo_acc"]),
    )
    for task_name, expected_keys in per_task_expectations:
      self.assertContainsSubset(expected_keys, results[task_name].keys())
    # Only per-task entries are expected; no aggregated loss at the top level.
    self.assertNotIn("total_loss", results)
def test_multitask_evaluator(self, distribution):
  """Multi-task evaluator returns per-task validation metrics.

  Evaluates a bar/foo pair of `MockTask`s once with a scalar int32 tensor
  of 1, checks that each task's results contain `validation_loss` and
  `acc`, and pins the expected `validation_loss` values (0.0 for "bar",
  1.0 for "foo").
  """
  with distribution.scope():
    task_list = [
        MockTask(params=cfg.TaskConfig(), name=task_name)
        for task_name in ("bar", "foo")
    ]
    multi_task = multitask.MultiTask(tasks=task_list)
    mock_model = MockModel()
    multi_evaluator = evaluator.MultiTaskEvaluator(
        task=multi_task, model=mock_model)
    results = multi_evaluator.evaluate(
        tf.convert_to_tensor(1, dtype=tf.int32))
    # Key checks first for both tasks, then the exact loss values.
    for task_name in ("bar", "foo"):
      self.assertContainsSubset(["validation_loss", "acc"],
                                results[task_name].keys())
    for task_name, expected_loss in (("bar", 0.0), ("foo", 1.0)):
      self.assertEqual(results[task_name]["validation_loss"], expected_loss)
def test_multitask_evaluator_numpy_metrics(self, distribution):
  """Per-task `counter` results scale with the number of replicas.

  Evaluates a bar/foo pair of `MockTask`s once with a scalar int32 tensor
  of 5 and expects each task's `counter` result to equal
  5 * `distribution.num_replicas_in_sync`.
  """
  with distribution.scope():
    task_list = [
        MockTask(params=cfg.TaskConfig(), name=task_name)
        for task_name in ("bar", "foo")
    ]
    multi_evaluator = evaluator.MultiTaskEvaluator(
        task=multitask.MultiTask(tasks=task_list), model=MockModel())
    results = multi_evaluator.evaluate(
        tf.convert_to_tensor(5, dtype=tf.int32))
    expected_count = 5. * distribution.num_replicas_in_sync
    for task_name in ("bar", "foo"):
      self.assertEqual(results[task_name]["counter"], expected_count)
def test_multitask_joint_trainer(self, distribution):
  """Joint base trainer reports per-task training metrics.

  Builds a foo/bar multi-task with an explicit weight of 1.0 per task, calls
  `train()` once with a scalar int32 tensor of 5, and checks that each
  task's results contain `training_loss` plus that task's accuracy key.
  """
  with distribution.scope():
    foo_task = test_utils.MockFooTask(
        params=test_utils.FooConfig(), name="foo")
    bar_task = test_utils.MockBarTask(
        params=test_utils.BarConfig(), name="bar")
    multi_task = multitask.MultiTask(
        tasks=[foo_task, bar_task],
        task_weights={"foo": 1.0, "bar": 1.0})
    sgd = tf.keras.optimizers.SGD(learning_rate=0.1)
    mock_model = test_utils.MockMultiTaskModel()
    trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=multi_task, multi_task_model=mock_model, optimizer=sgd)
    results = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    per_task_expectations = (
        ("bar", ["training_loss", "bar_acc"]),
        ("foo", ["training_loss", "foo_acc"]),
    )
    for task_name, expected_keys in per_task_expectations:
      self.assertContainsSubset(expected_keys, results[task_name].keys())