def test_multitask_interleaving_trainer(self, distribution):
  """Runs the interleaving trainer on the mock foo/bar tasks.

  Builds a two-task MultiTask (no explicit task weights), a uniform task
  sampler fed with the MultiTask's `task_weights`, and a
  `MultiTaskInterleavingTrainer`. It then calls `train` with an int32 tensor
  holding 5 (presumably the step count — confirm against the trainer API).

  Asserts that:
    * `results["foo"]` contains "training_loss" and "foo_acc";
    * `results["bar"]` contains "training_loss" and "bar_acc";
    * there is no top-level "total_loss" entry in `results`.

  NOTE(review): `distribution` must be supplied by a parameterizing decorator
  that is not visible in this chunk — confirm it is applied.
  """
  with distribution.scope():
    tasks = [
        test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
        test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
    ]
    test_multitask = multitask.MultiTask(tasks=tasks)
    test_optimizer = tf.keras.optimizers.SGD(0.1)
    model = test_utils.MockMultiTaskModel()
    # The sampler reuses whatever weights MultiTask ended up with, since none
    # were passed explicitly above.
    sampler = task_sampler.UniformTaskSampler(
        task_weights=test_multitask.task_weights)
    test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
        multi_task=test_multitask,
        multi_task_model=model,
        optimizer=test_optimizer,
        task_sampler=sampler)
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    # Per-task results only: this trainer must not report a combined loss.
    self.assertNotIn("total_loss", results)

 def test_multitask_joint_trainer(self, distribution):
   """Runs MultiTaskBaseTrainer on the mock foo/bar tasks and checks the logs.

   Both tasks are given a weight of 1.0. `train` is called with an int32
   tensor holding 5 (presumably the step count — confirm against the trainer
   API). Each task's entry in the returned results must contain
   "training_loss" plus its own accuracy key ("foo_acc" / "bar_acc").

   NOTE(review): `distribution` must be supplied by a parameterizing decorator
   that is not visible in this chunk — confirm it is applied.
   """
   weights_by_task = {"foo": 1.0, "bar": 1.0}
   with distribution.scope():
     foo_task = test_utils.MockFooTask(
         params=test_utils.FooConfig(), name="foo")
     bar_task = test_utils.MockBarTask(
         params=test_utils.BarConfig(), name="bar")
     joint_task = multitask.MultiTask(
         tasks=[foo_task, bar_task], task_weights=weights_by_task)
     sgd = tf.keras.optimizers.SGD(0.1)
     mock_model = test_utils.MockMultiTaskModel()
     trainer = base_trainer.MultiTaskBaseTrainer(
         multi_task=joint_task, multi_task_model=mock_model, optimizer=sgd)
     num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
     logs = trainer.train(num_steps)
     # (task name, keys that must appear in that task's logs); "bar" is
     # checked before "foo".
     expected_keys = (
         ("bar", ["training_loss", "bar_acc"]),
         ("foo", ["training_loss", "foo_acc"]),
     )
     for task_name, keys in expected_keys:
       self.assertContainsSubset(keys, logs[task_name].keys())