def test_trainer_with_configs(self, distribution):
     """Trains a config-built interleaving trainer with annealing task sampling.

     Two task routines ("foo" with weight 3.0, "bar" with weight 1.0) are
     trained for `num_step` steps by `MultiTaskInterleavingTrainer` driven by an
     `AnnealingTaskSampler`. The test checks that each task reports its training
     loss and accuracy metric, that the global step equals `num_step`, and that
     the per-task step counters sum to `num_step`.
     """
     config = configs.MultiTaskConfig(task_routines=(
         configs.TaskRoutine(task_name="foo",
                             task_config=test_utils.FooConfig(),
                             task_weight=3.0),
         configs.TaskRoutine(task_name="bar",
                             task_config=test_utils.BarConfig(),
                             task_weight=1.0)))
     num_step = 1000
     # Build every stateful object (multitask, optimizer, model, trainer) and
     # run training under the strategy scope, as
     # `test_multitask_interleaving_trainer` does. Variables created outside
     # `distribution.scope()` are not distributed by that strategy, so the
     # parameterized `distribution` would otherwise go unexercised.
     with distribution.scope():
         test_multitask = multitask.MultiTask.from_config(config)
         test_optimizer = tf.keras.optimizers.SGD(0.1)
         model = test_utils.MockMultiTaskModel()
         sampler = task_sampler.AnnealingTaskSampler(
             task_weights=test_multitask.task_weights,
             # Integer division: an epoch is a whole number of steps.
             steps_per_epoch=num_step // 5,
             total_steps=num_step)
         test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
             multi_task=test_multitask,
             multi_task_model=model,
             optimizer=test_optimizer,
             task_sampler=sampler)
         results = test_trainer.train(
             tf.convert_to_tensor(num_step, dtype=tf.int32))
     self.assertContainsSubset(["training_loss", "bar_acc"],
                               results["bar"].keys())
     self.assertContainsSubset(["training_loss", "foo_acc"],
                               results["foo"].keys())
     self.assertEqual(test_trainer.global_step.numpy(), num_step)
     # Every trained step is attributed to exactly one task, so the per-task
     # counters must add up to the total number of steps.
     bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
     foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
     self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
 def test_multitask_interleaving_trainer(self, distribution):
     """Runs five interleaved training steps over two mock tasks.

     All objects are built, and training runs, under `distribution.scope()`
     with a `UniformTaskSampler`. Each task's result entry must contain its
     training loss and accuracy metric, and the top-level results must not
     contain a "total_loss" key.
     """
     with distribution.scope():
         mt = multitask.MultiTask(tasks=[
             test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
             test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
         ])
         sgd = tf.keras.optimizers.SGD(0.1)
         mock_model = test_utils.MockMultiTaskModel()
         uniform_sampler = task_sampler.UniformTaskSampler(
             task_weights=mt.task_weights)
         trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
             multi_task=mt,
             multi_task_model=mock_model,
             optimizer=sgd,
             task_sampler=uniform_sampler)
         logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
         for task_name, acc_key in (("bar", "bar_acc"), ("foo", "foo_acc")):
             self.assertContainsSubset(["training_loss", acc_key],
                                       logs[task_name].keys())
         self.assertNotIn("total_loss", logs)
 def test_trainer_with_configs(self):
   """Trains a config-built `MultiTaskBaseTrainer` for five steps.

   Both task routines carry weight 0.5. Afterwards each task must report its
   training loss and accuracy metric, the "foo" weight must read back as 0.5,
   the global step must equal 5, and "learning_rate" must be in the results.
   """
   foo_routine = configs.TaskRoutine(
       task_name="foo", task_config=test_utils.FooConfig(), task_weight=0.5)
   bar_routine = configs.TaskRoutine(
       task_name="bar", task_config=test_utils.BarConfig(), task_weight=0.5)
   mt = multitask.MultiTask.from_config(
       configs.MultiTaskConfig(task_routines=(foo_routine, bar_routine)))
   sgd = tf.keras.optimizers.SGD(0.1)
   mock_model = test_utils.MockMultiTaskModel()
   trainer = base_trainer.MultiTaskBaseTrainer(
       multi_task=mt, multi_task_model=mock_model, optimizer=sgd)
   logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
   for task_name, acc_key in (("bar", "bar_acc"), ("foo", "foo_acc")):
     self.assertContainsSubset(["training_loss", acc_key],
                               logs[task_name].keys())
   self.assertEqual(mt.task_weight("foo"), 0.5)
   self.assertEqual(trainer.global_step.numpy(), 5)
   self.assertIn("learning_rate", logs)
 def test_multitask_joint_trainer(self, distribution):
   """Trains `MultiTaskBaseTrainer` for five steps under `distribution`.

   Both mock tasks are given weight 1.0. Each task's result entry must contain
   its training loss and accuracy metric.
   """
   with distribution.scope():
     mock_tasks = [
         test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
         test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
     ]
     weights = {"foo": 1.0, "bar": 1.0}
     mt = multitask.MultiTask(tasks=mock_tasks, task_weights=weights)
     sgd = tf.keras.optimizers.SGD(0.1)
     mock_model = test_utils.MockMultiTaskModel()
     trainer = base_trainer.MultiTaskBaseTrainer(
         multi_task=mt, multi_task_model=mock_model, optimizer=sgd)
     logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     for task_name, acc_key in (("bar", "bar_acc"), ("foo", "foo_acc")):
       self.assertContainsSubset(["training_loss", acc_key],
                                 logs[task_name].keys())
 def test_end_to_end(self, distribution_strategy, flag_mode):
   """Runs `train_lib.run_experiment` on a two-task experiment config.

   The base config lists the "foo" and "bar" routines. `self._test_config` is
   layered on top (non-strict override), and the multitask and mock model are
   built under `distribution_strategy` before the experiment runs in
   `flag_mode` with a fresh temporary model directory.
   """
   out_dir = self.get_temp_dir()
   routines = (
       configs.TaskRoutine(
           task_name='foo', task_config=test_utils.FooConfig()),
       configs.TaskRoutine(
           task_name='bar', task_config=test_utils.BarConfig()),
   )
   base_config = configs.MultiTaskExperimentConfig(
       task=configs.MultiTaskConfig(task_routines=routines))
   params = params_dict.override_params_dict(
       base_config, self._test_config, is_strict=False)
   with distribution_strategy.scope():
     mt = multitask.MultiTask.from_config(params.task)
     mock_model = test_utils.MockMultiTaskModel()
   train_lib.run_experiment(
       distribution_strategy=distribution_strategy,
       task=mt,
       model=mock_model,
       mode=flag_mode,
       params=params,
       model_dir=out_dir)