  def test_uniform_sample_distribution(self):
    uniform_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(type='uniform'), self._task_weights)
    for step in range(5):
      cumulative_distribution = uniform_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.333333, 0.666666, 1.0],
                          cumulative_distribution.numpy())
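
  # A hedged cross-check (not part of the original suite): re-derives the
  # uniform expectations above with plain Python, assuming the uniform
  # sampler simply ignores the task weights and splits probability evenly
  # across the three tasks.
  def test_uniform_distribution_math_sketch(self):
    num_tasks = 3
    cumulative = [(i + 1) / num_tasks for i in range(num_tasks)]
    self.assertAllClose([0.333333, 0.666666, 1.0], cumulative)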
  def test_proportional_sample_distribution(self):
    prop_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='proportional',
            proportional=configs.ProportionalSampleConfig(alpha=2.0)),
        self._task_weights)
    # Cumulative of normalize([1.0**2, 2.0**2, 3.0**2]).
    for step in range(5):
      cumulative_distribution = prop_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.07142857, 0.35714286, 1.0],
                          cumulative_distribution.numpy())
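
  # A hedged cross-check (not part of the original suite): re-derives the
  # proportional expectations above with plain Python, assuming
  # self._task_weights is equivalent to [1.0, 2.0, 3.0] and that the sampler
  # normalizes weights**alpha before taking the cumulative sum.
  def test_proportional_distribution_math_sketch(self):
    weights = [1.0, 2.0, 3.0]
    powered = [w**2.0 for w in weights]  # alpha = 2.0
    total = sum(powered)
    cumulative = [sum(powered[:i + 1]) / total for i in range(len(powered))]
    self.assertAllClose([0.07142857, 0.35714286, 1.0], cumulative)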
  def test_annealing_sample_distribution(self):
    num_epochs = 3
    steps_per_epoch = 6
    anneal_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='annealing',
            annealing=configs.AnnealingSampleConfig(
                steps_per_epoch=steps_per_epoch,
                total_steps=steps_per_epoch * num_epochs)),
        self._task_weights)
    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epochs):
      for _ in range(steps_per_epoch):
        cumulative_distribution = anneal_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected_cumulative_epochs[epoch],
                            cumulative_distribution.numpy())
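
  # A hedged observation (deduced from the expected values above, not from
  # the sampler's documented schedule): each epoch's expectations coincide
  # with proportional sampling over weights [1.0, 2.0, 3.0] whose exponent
  # steps from 1.4 to 1.0 to 0.6, i.e. annealing toward a more uniform
  # distribution as training progresses.
  def test_annealing_distribution_math_sketch(self):
    weights = [1.0, 2.0, 3.0]
    per_epoch_alphas = [1.4, 1.0, 0.6]
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for alpha, expected in zip(per_epoch_alphas, expected_cumulative_epochs):
      powered = [w**alpha for w in weights]
      total = sum(powered)
      cumulative = [sum(powered[:i + 1]) / total for i in range(len(powered))]
      self.assertAllClose(expected, cumulative)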
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel, mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A MultiTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval',
      'train_and_eval' or 'continuous_eval'.
    params: A MultiTaskExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
    model: The `base_model.MultiTaskBaseModel` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    if params.trainer.trainer_type == 'interleaving':
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    trainer = TRAINERS[params.trainer.trainer_type](
        **kwargs) if is_training else None
    if is_eval:
      eval_steps = task.task_eval_steps
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=task.tasks.values(),
          model=model,
          eval_steps=eval_steps,
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starting to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        return evaluator.global_step.numpy() >= params.trainer.train_steps

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
  return model
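
# A minimal usage sketch for `run_experiment` (illustrative, and kept as a
# comment so importing this module stays side-effect free). `build_multitask`
# and `build_multitask_model` are hypothetical factories standing in for
# whatever config-driven construction the caller already has; the one-device
# strategy is just a placeholder.
#
#   strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
#   with strategy.scope():
#     task = build_multitask(params)          # hypothetical factory
#     model = build_multitask_model(params)   # hypothetical factory
#   trained_model = run_experiment(
#       distribution_strategy=strategy,
#       task=task,
#       model=model,
#       mode='train_and_eval',
#       params=params,
#       model_dir='/tmp/multitask_experiment')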