def test_single_task_training(self):
  iris = tfds.load('iris')
  train_ds = iris['train'].batch(32).repeat()

  model = tf.keras.Sequential([
      tf.keras.Input(shape=(4,), name='features'),
      tf.keras.layers.Dense(10, activation=tf.nn.relu),
      tf.keras.layers.Dense(10, activation=tf.nn.relu),
      tf.keras.layers.Dense(3),
      tf.keras.layers.Softmax(),
  ])

  trainer = single_task_trainer.SingleTaskTrainer(
      train_ds,
      label_key='label',
      model=model,
      loss_fn=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))

  controller = orbit.Controller(
      trainer=trainer,
      steps_per_loop=100,
      global_step=trainer.optimizer.iterations)

  controller.train(1)
  start_loss = trainer.train_loss.result().numpy()
  controller.train(500)
  end_loss = trainer.train_loss.result().numpy()

  # Assert that the model has trained 'significantly' - that the loss
  # has dropped by over 50%.
  self.assertLess(end_loss, start_loss / 2)
def test_single_task_evaluation(self):
  iris = tfds.load('iris')
  train_ds = iris['train'].batch(32)

  model = tf.keras.Sequential([
      tf.keras.Input(shape=(4,), name='features'),
      tf.keras.layers.Dense(10, activation=tf.nn.relu),
      tf.keras.layers.Dense(10, activation=tf.nn.relu),
      tf.keras.layers.Dense(3),
  ])

  trainer = single_task_trainer.SingleTaskTrainer(
      train_ds,
      label_key='label',
      model=model,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))

  evaluator = single_task_evaluator.SingleTaskEvaluator(
      train_ds,
      label_key='label',
      model=model,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  controller = orbit.Controller(
      trainer=trainer,
      evaluator=evaluator,
      steps_per_loop=100,
      global_step=trainer.optimizer.iterations)

  controller.train(train_ds.cardinality().numpy())
  controller.evaluate()
  accuracy = evaluator.metrics[0].result().numpy()

  self.assertGreater(0.925, accuracy)
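# Usage sketch (illustrative, not part of the tests above): the same
# SingleTaskTrainer / SingleTaskEvaluator pair driven by
# Controller.train_and_evaluate for periodic evaluation. The step counts and the
# reuse of the iris training split for evaluation are assumptions chosen to keep
# the example small.
def train_and_evaluate_iris_sketch():
  iris = tfds.load('iris')
  train_ds = iris['train'].batch(32).repeat()
  eval_ds = iris['train'].batch(32)

  model = tf.keras.Sequential([
      tf.keras.Input(shape=(4,), name='features'),
      tf.keras.layers.Dense(10, activation=tf.nn.relu),
      tf.keras.layers.Dense(3),
  ])

  trainer = single_task_trainer.SingleTaskTrainer(
      train_ds,
      label_key='label',
      model=model,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))
  evaluator = single_task_evaluator.SingleTaskEvaluator(
      eval_ds,
      label_key='label',
      model=model,
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  controller = orbit.Controller(
      trainer=trainer,
      evaluator=evaluator,
      steps_per_loop=100,
      global_step=trainer.optimizer.iterations)
  # Train for 1000 steps, evaluating on 5 batches every 200 steps.
  controller.train_and_evaluate(
      train_steps=1000, eval_steps=5, eval_interval=200)
  return evaluator.metrics[0].result().numpy()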
def run_experiment(distribution_strategy: tf.distribute.Strategy,
                   task: base_task.Task,
                   mode: str,
                   params: config_definitions.ExperimentConfig,
                   model_dir: str,
                   run_post_eval: bool = False,
                   save_summary: bool = True
                   ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: An ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run a post-training evaluation once; its metric
      logs are returned.
    save_summary: Whether to save train and validation summaries.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: The `tf.keras.Model` instance.
      eval_logs: The eval metrics logs when run_post_eval is set to True;
        otherwise, an empty dictionary.
  """
  with distribution_strategy.scope():
    trainer = train_utils.create_trainer(
        params,
        task,
        model_dir=model_dir,
        train='train' in mode,
        evaluate=('eval' in mode) or run_post_eval,
        checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))

  if trainer.checkpoint:
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
  else:
    checkpoint_manager = None

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer if 'train' in mode else None,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=os.path.join(model_dir, 'validation')
      if save_summary else None,
      summary_interval=params.trainer.summary_interval
      if save_summary else None)

  logging.info('Starting to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if trainer.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  if run_post_eval:
    with distribution_strategy.scope():
      return trainer.model, trainer.evaluate(
          tf.convert_to_tensor(params.trainer.validation_steps))
  else:
    return trainer.model, {}
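# Usage sketch (assumption): one way to invoke run_experiment above. The
# experiment name 'resnet_imagenet', the model_dir path, and the
# exp_factory / task_factory helpers (as exposed by the TF Model Garden's
# official.core package) are assumptions; adapt them to the surrounding project.
def run_experiment_example_sketch():
  # Imports assumed to resolve against the TF Model Garden package layout.
  from official.core import exp_factory
  from official.core import task_factory

  model_dir = '/tmp/example_model_dir'  # Illustrative path.
  params = exp_factory.get_exp_config('resnet_imagenet')
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)
  model, eval_logs = run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode='train_and_eval',
      params=params,
      model_dir=model_dir,
      run_post_eval=True)
  return model, eval_logs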
def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using custom training loops.

  Args:
    flags_obj: An object containing parsed flag values.

  Raises:
    ValueError: If fp16 is passed, as it is not currently supported.

  Returns:
    Dictionary of training and eval stats.
  """
  keras_utils.set_session_config()
  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

  if tf.config.list_physical_devices('GPU'):
    if flags_obj.tf_gpu_thread_mode:
      keras_utils.set_gpu_thread_mode_and_count(
          per_gpu_thread_count=flags_obj.per_gpu_thread_count,
          gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
          num_gpus=flags_obj.num_gpus,
          datasets_num_private_threads=flags_obj.datasets_num_private_threads)
    common.set_cudnn_batchnorm_mode()

  data_format = flags_obj.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.config.list_physical_devices('GPU') else
                   'channels_last')
  tf.keras.backend.set_image_data_format(data_format)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs,
      tpu_address=flags_obj.tpu)

  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
      flags_obj)
  if flags_obj.steps_per_loop is None:
    steps_per_loop = per_epoch_steps
  elif flags_obj.steps_per_loop > per_epoch_steps:
    steps_per_loop = per_epoch_steps
    logging.warning('Setting steps_per_loop to %d to respect epoch boundary.',
                    steps_per_loop)
  else:
    steps_per_loop = flags_obj.steps_per_loop

  logging.info(
      'Training %d epochs, each epoch has %d steps, '
      'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
      train_epochs * per_epoch_steps, eval_steps)

  time_callback = keras_utils.TimeHistory(
      flags_obj.batch_size,
      flags_obj.log_steps,
      logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
  with distribute_utils.get_strategy_scope(strategy):
    runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                              per_epoch_steps)

  eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
  checkpoint_interval = (
      steps_per_loop * 5 if flags_obj.enable_checkpoint_and_export else None)
  summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None

  checkpoint_manager = tf.train.CheckpointManager(
      runnable.checkpoint,
      directory=flags_obj.model_dir,
      max_to_keep=10,
      step_counter=runnable.global_step,
      checkpoint_interval=checkpoint_interval)

  resnet_controller = orbit.Controller(
      strategy=strategy,
      trainer=runnable,
      evaluator=runnable if not flags_obj.skip_eval else None,
      global_step=runnable.global_step,
      steps_per_loop=steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_interval=summary_interval,
      summary_dir=flags_obj.model_dir,
      eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

  time_callback.on_train_begin()
  if not flags_obj.skip_eval:
    resnet_controller.train_and_evaluate(
        train_steps=per_epoch_steps * train_epochs,
        eval_steps=eval_steps,
        eval_interval=eval_interval)
  else:
    resnet_controller.train(steps=per_epoch_steps * train_epochs)
  time_callback.on_train_end()

  stats = build_stats(runnable, time_callback)
  return stats
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel,
                   mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A MultiTask instance.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: A MultiTaskExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
    model: The `base_model.MultiTaskBaseModel` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    if params.trainer.trainer_type == 'interleaving':
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    trainer = TRAINERS[params.trainer.trainer_type](
        **kwargs) if is_training else None
    if is_eval:
      eval_steps = task.task_eval_steps
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=task.tasks.values(),
          model=model,
          eval_steps=eval_steps,
          global_step=trainer.global_step if is_training else None,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starting to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  return model
def run_experiment_with_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: List[base_task.Task],
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[core_lib.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    train_task: A base_task.Task instance.
    eval_tasks: A list of evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: A MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run a post-training evaluation once; its metric
      logs are returned.
    save_summary: Whether to save train and validation summaries.
    trainer: A core_lib.Trainer instance. It should be created within the
      strategy.scope(). If not provided, an instance will be created by default
      if `mode` contains 'train'.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: The `tf.keras.Model` instance.
      eval_logs: The eval metrics logs when run_post_eval is set to True;
        otherwise, an empty dictionary.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    if is_training:
      trainer = trainer or core_lib.Trainer(
          config=params,
          task=train_task,
          model=train_task.build_model(),
          optimizer=train_task.create_optimizer(
              params.trainer.optimizer_config, params.runtime),
          train=True,
          evaluate=False)
    else:
      trainer = None
    model = trainer.model if trainer else train_task.build_model()

    if is_eval:
      eval_steps = dict([(task_routine.task_config.name,
                          task_routine.eval_steps)
                         for task_routine in params.eval_tasks])
      evaluator = evaluator_lib.MultiTaskEvaluator(
          eval_tasks=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None,
          eval_steps=eval_steps,
          checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
              params, model_dir))
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=os.path.join(model_dir, 'validation')
      if save_summary else None,
      summary_interval=params.trainer.summary_interval
      if save_summary else None)

  logging.info('Starting to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  if run_post_eval:
    return model, evaluator.evaluate(
        tf.convert_to_tensor(params.trainer.validation_steps))
  else:
    return model, {}
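# Usage sketch (assumption): building the `eval_tasks` list expected by
# run_experiment_with_multitask_eval from the experiment config. The
# task_factory.get_task helper is the usual Model Garden factory; passing
# `logging_dir` here is an assumption, so adapt it to the surrounding project.
def build_eval_tasks_sketch(params, model_dir):
  from official.core import task_factory  # Assumed import path.
  return [
      task_factory.get_task(task_routine.task_config, logging_dir=model_dir)
      for task_routine in params.eval_tasks
  ]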
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: An ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run a post-training evaluation once; its metric
      logs are returned.
    save_summary: Whether to save train and validation summaries.
    trainer: A base_trainer.Trainer instance. It should be created within the
      strategy.scope().

  Returns:
    A 2-tuple of (model, eval_logs).
      model: The `tf.keras.Model` instance.
      eval_logs: The eval metrics logs when run_post_eval is set to True;
        otherwise, an empty dictionary.
  """
  with distribution_strategy.scope():
    if not trainer:
      trainer = train_utils.create_trainer(
          params,
          task,
          train='train' in mode,
          evaluate=('eval' in mode) or run_post_eval,
          checkpoint_exporter=maybe_create_best_ckpt_exporter(
              params, model_dir))

  if trainer.checkpoint:
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
    # Adds recovery handling.
    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
  else:
    checkpoint_manager = None

  # Create logs matching the TensorBoard log parser format
  # (see tensorboard_for_parser.md).
  hparams = {
      "batch_size": params.task.train_data.global_batch_size,
      "precision": params.runtime.mixed_precision_dtype
  }

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer if 'train' in mode else None,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=model_dir if save_summary else None,
      eval_summary_dir=os.path.join(
          model_dir, params.trainer.validation_summary_subdir)
      if save_summary else None,
      summary_interval=params.trainer.summary_interval
      if save_summary else None,
      hparams=hparams if save_summary else None,
      train_actions=None,
      eval_actions=actions.get_eval_actions(params, trainer, model_dir))

  logging.info('Starting to execute mode: %s', mode)
  with distribution_strategy.scope():
    if params.runtime.dump_config:
      from TensorFlow.common.debug import dump_callback
    with dump_callback(
        params.runtime.dump_config
    ) if params.runtime.dump_config else contextlib.ExitStack():
      if mode == 'train':
        controller.train(steps=params.trainer.train_steps)
      elif mode == 'train_and_eval':
        controller.train_and_evaluate(
            train_steps=params.trainer.train_steps,
            eval_steps=params.trainer.validation_steps,
            eval_interval=params.trainer.validation_interval)
      elif mode == 'eval':
        controller.evaluate(steps=params.trainer.validation_steps)
      elif mode == 'continuous_eval':

        def timeout_fn():
          if trainer.global_step.numpy() >= params.trainer.train_steps:
            return True
          return False

        controller.evaluate_continuously(
            steps=params.trainer.validation_steps,
            timeout=params.trainer.continuous_eval_timeout,
            timeout_fn=timeout_fn)
      else:
        raise NotImplementedError('The mode is not implemented: %s' % mode)

  num_params = train_utils.try_count_params(trainer.model)
  if num_params is not None:
    logging.info('Number of trainable params in model: %f Millions.',
                 num_params / 10.**6)

  if run_post_eval:
    with distribution_strategy.scope():
      return trainer.model, trainer.evaluate(
          tf.convert_to_tensor(params.trainer.validation_steps))
  else:
    return trainer.model, {}
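# Fallback sketch (assumption): train_utils.try_count_params is a Model Garden
# helper; if it is unavailable in a given environment, the trainable-parameter
# count can be computed directly with plain TensorFlow as shown below.
def count_trainable_params(model: tf.keras.Model) -> int:
  """Returns the total number of trainable parameters in `model`."""
  return int(sum(tf.size(w).numpy() for w in model.trainable_variables))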
def train(model_optimizer_fn, train_steps, eval_steps, steps_between_evals,
          train_dataset, test_dataset, experiment_dir):
  """Performs training.

  Args:
    model_optimizer_fn: Function that returns a tuple containing the model and
      its optimizer.
    train_steps: Total number of steps to train for.
    eval_steps: Number of steps to evaluate for.
    steps_between_evals: Number of steps to train for between evaluations.
    train_dataset: Dataset to use for training.
    test_dataset: Dataset to use for evaluation.
    experiment_dir: Directory in which to save results.
  """
  test_dataset_orig = test_dataset
  if FLAGS.tpu is not None:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_dataset = strategy.experimental_distribute_dataset(test_dataset)

  summary_and_checkpoint_dir = os.path.join(experiment_dir, 'checkpoints')
  best_checkpoint_path = os.path.join(experiment_dir, 'checkpoints', 'best')
  final_model_dir = os.path.join(experiment_dir, 'final_model')
  best_model_dir = os.path.join(experiment_dir, 'best_model')

  # Load previous best accuracy and previous eval step out of summaries.
  summaries, summary_steps = get_summaries_from_dir(summary_and_checkpoint_dir)
  previous_best_accuracy = 0.0
  previous_eval_steps = 0
  if summaries[EVAL_ACCURACY_KEY]:
    previous_best_accuracy = max(summaries[EVAL_ACCURACY_KEY])
    previous_eval_steps = max(summary_steps[EVAL_ACCURACY_KEY])

  try:
    with strategy.scope() if FLAGS.tpu else contextlib.suppress():
      model, optimizer = model_optimizer_fn()

      def _loss_fn(labels, logits):
        """Compute total loss."""
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits))
        # Add weight decay losses to final loss.
        return loss + tf.reduce_sum(model.losses)

      checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
      manager = tf.train.CheckpointManager(
          checkpoint=checkpoint,
          directory=summary_and_checkpoint_dir,
          max_to_keep=(None if FLAGS.mode == 'train' else
                       FLAGS.checkpoints_to_keep),
          step_counter=optimizer.iterations,
          checkpoint_interval=steps_between_evals)

      trainer = single_task_trainer.SingleTaskTrainer(
          train_dataset=train_dataset,
          label_key='label',
          model=model,
          loss_fn=_loss_fn,
          optimizer=optimizer,
          metrics=[
              tf.keras.metrics.SparseCategoricalAccuracy(
                  name='accuracy/train'),
              tf.keras.metrics.SparseCategoricalCrossentropy(
                  name='loss/train'),
          ])
      evaluator = single_task_evaluator.SingleTaskEvaluator(
          eval_dataset=test_dataset,
          label_key='label',
          model=model,
          metrics=[
              tf.keras.metrics.SparseCategoricalAccuracy(
                  name=EVAL_ACCURACY_KEY),
              tf.keras.metrics.SparseCategoricalCrossentropy(name='loss/eval'),
          ])

      controller = orbit.Controller(
          trainer=trainer,
          evaluator=evaluator,
          steps_per_loop=steps_between_evals,
          global_step=optimizer.iterations,
          checkpoint_manager=manager)
      controller.restore_checkpoint()

      while optimizer.iterations < train_steps:
        current_steps = optimizer.iterations.numpy()
        if ('evaluate' in FLAGS.mode and
            (current_steps // steps_between_evals >
             previous_eval_steps // steps_between_evals)):
          logging.info('Skipping training because eval is out-of-date.')
        else:
          current_train_steps = min(
              (current_steps // steps_between_evals + 1) * steps_between_evals,
              train_steps)
          if 'train' in FLAGS.mode:
            controller.train(current_train_steps)
          elif 'evaluate' in FLAGS.mode:
            next_checkpoint_path = os.path.join(
                summary_and_checkpoint_dir,
                'ckpt-{}'.format(current_train_steps))
            while not tf.io.gfile.exists(next_checkpoint_path + '.index'):
              logging.info('Checkpoint %s not yet ready.',
                           next_checkpoint_path)
              time.sleep(15)
            checkpoint.restore(next_checkpoint_path)

        if 'evaluate' in FLAGS.mode:
          controller.evaluate(eval_steps)
          current_accuracy = evaluator.eval_end()[EVAL_ACCURACY_KEY]
          current_train_loss = trainer.train_loop_end()['loss/train']
          previous_eval_steps = optimizer.iterations.numpy()

          if current_accuracy > previous_best_accuracy:
            logging.info(
                'New accuracy %.4f beats best previous accuracy %.4f; saving '
                'new best checkpoint.', current_accuracy,
                previous_best_accuracy)
            previous_best_accuracy = current_accuracy
            checkpoint.write(best_checkpoint_path)

          if FLAGS.mode == 'evaluate':
            # Delete checkpoints if we have hit max. We do this in the eval job
            # to make sure that we aren't deleting checkpoints we haven't yet
            # evaluated.
            checkpoint_paths = tf.io.gfile.glob(
                os.path.join(summary_and_checkpoint_dir, 'ckpt-*.index'))
            checkpoint_paths_nums = [
                (int(re.search(r'/ckpt-([0-9]+).index$', x).group(1)), x)
                for x in checkpoint_paths
            ]
            checkpoint_paths_nums.sort()
            for num, path in checkpoint_paths_nums[:-FLAGS.checkpoints_to_keep]:
              if num <= optimizer.iterations.numpy():
                # Don't delete unevaluated checkpoints.
                logging.info('Removing old checkpoint %s.', path)
                tf.io.gfile.remove(path)

      if 'evaluate' in FLAGS.mode:
        # Save final model.
        tf.io.gfile.mkdir(final_model_dir)
        save_predictions(model, test_dataset_orig, final_model_dir)
        tf.keras.models.save_model(model, final_model_dir)

        # At end of training, load best checkpoint and write Keras saved model.
        # We do not do this during training because it is very slow.
        checkpoint.restore(best_checkpoint_path)
        tf.io.gfile.mkdir(best_model_dir)
        save_predictions(model, test_dataset_orig, best_model_dir)
        tf.keras.models.save_model(model, best_model_dir)
  except tf.errors.UnavailableError:
    logging.info('Lost contact with TPU; restarting.')
    sys.exit(42)
def main(_):
  # Set up experiment params and load the configs from file/files.
  experiment_params = params.EdgeTPUBERTCustomParams()
  experiment_params = utils.config_override(experiment_params, FLAGS)
  model_dir = utils.get_model_dir(experiment_params, FLAGS)

  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=experiment_params.runtime.distribution_strategy,
      all_reduce_alg=experiment_params.runtime.all_reduce_alg,
      num_gpus=experiment_params.runtime.num_gpus,
      tpu_address=experiment_params.runtime.tpu_address)

  with distribution_strategy.scope():
    teacher_model = model_builder.build_bert_pretrainer(
        pretrainer_cfg=experiment_params.teacher_model,
        quantization_friendly=False,
        name='teacher')
    student_model = model_builder.build_bert_pretrainer(
        pretrainer_cfg=experiment_params.student_model,
        quantization_friendly=True,
        name='student')

    # Load model weights.
    teacher_ckpt_dir_or_file = experiment_params.teacher_model_init_checkpoint
    if not teacher_ckpt_dir_or_file:
      raise ValueError('`teacher_model_init_checkpoint` is not specified.')
    utils.load_checkpoint(teacher_model, teacher_ckpt_dir_or_file)

    student_ckpt_dir_or_file = experiment_params.student_model_init_checkpoint
    if not student_ckpt_dir_or_file:
      # Makes sure the pretrainer variables are created.
      _ = student_model(student_model.inputs)
      logging.warning('No student checkpoint is provided, training might take '
                      'much longer before converging.')
    else:
      utils.load_checkpoint(student_model, student_ckpt_dir_or_file)

  runner = mobilebert_edgetpu_trainer.MobileBERTEdgeTPUDistillationTrainer(
      teacher_model=teacher_model,
      student_model=student_model,
      strategy=distribution_strategy,
      experiment_params=experiment_params,
      export_ckpt_path=model_dir)

  # Save checkpoint for preemption handling.
  # Checkpoints for downstream tasks are saved separately inside the
  # runner's train_loop_end() function.
  checkpoint = tf.train.Checkpoint(
      teacher_model=runner.teacher_model,
      student_model=runner.student_model,
      layer_wise_optimizer=runner.layer_wise_optimizer,
      e2e_optimizer=runner.e2e_optimizer,
      current_step=runner.current_step)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=5,
      step_counter=runner.current_step,
      checkpoint_interval=20000,
      init_fn=None)

  controller = orbit.Controller(
      trainer=runner,
      evaluator=runner,
      global_step=runner.current_step,
      strategy=distribution_strategy,
      steps_per_loop=experiment_params.orbit_config.steps_per_loop,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'eval'),
      checkpoint_manager=checkpoint_manager)

  if FLAGS.mode == 'train':
    controller.train(steps=experiment_params.orbit_config.total_steps)
  else:
    raise ValueError('Unsupported mode; only `train` is supported.')
def run_experiment_wtih_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: multitask.MultiTask,
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str) -> tf.keras.Model:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    train_task: A base_task.Task instance.
    eval_tasks: A multitask.MultiTask with evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: A MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
    model: The `tf.keras.Model` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = train_task.create_optimizer(params.trainer, params.runtime)
    model = train_task.build_model()
    if is_training:
      trainer = core_lib.Trainer(
          config=params,
          task=train_task,
          model=model,
          optimizer=optimizer,
          train=True,
          evaluate=False)
    else:
      trainer = None
    if is_eval:
      evaluator = evaluator_lib.MultiTaskEvaluator(
          task=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None)
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starting to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  return model