def decorator(*positional, **keyword):
    """Invokes the wrapped ``fn`` with mixed precision forcibly enabled.

    ``misc.disable_mixed_precision()`` is always called afterwards, whether
    ``fn`` returns normally or raises.
    """
    misc.enable_mixed_precision(force=True)
    try:
        outcome = fn(*positional, **keyword)
    finally:
        # Runs on both the success and the error path.
        misc.disable_mixed_precision()
    return outcome
def train(
    self,
    num_devices=1,
    with_eval=False,
    checkpoint_path=None,
    hvd=None,
    return_summary=False,
    fallback_to_cpu=True,
):
    """Runs the training loop.

    Args:
      num_devices: Number of devices to use for training.
      with_eval: Enable evaluation during training.
      checkpoint_path: The checkpoint path to load the model weights from.
      hvd: Optional Horovod module.
      return_summary: Return a summary of the training from this function.
      fallback_to_cpu: If no GPU is detected, allow the training to run on CPU.

    Returns:
      The path to the final model directory and, if :obj:`return_summary` is
      set, a dictionary with various training statistics. The path is ``None``
      on non-master Horovod workers.

    Raises:
      ValueError: if :obj:`hvd` is set and :obj:`num_devices` is greater than 1.
    """
    if hvd is None:
        num_replicas = num_devices
        is_master = True
    else:
        if num_devices > 1:
            raise ValueError(
                "num_devices (or num_gpus) should be set to 1 when using Horovod"
            )
        num_replicas = hvd.size()
        is_master = hvd.rank() == 0

    devices = misc.get_devices(count=num_devices, fallback_to_cpu=fallback_to_cpu)

    config = self._finalize_config(
        training=True, num_replicas=num_replicas, num_devices=num_devices
    )

    # Mixed precision is toggled through module-level misc helpers (presumably
    # process-wide state), so it is disabled in `finally` to avoid leaking the
    # setting into later calls when training fails.
    mixed_precision = self._mixed_precision and misc.enable_mixed_precision()
    try:
        model = self._init_model(config)
        optimizer = model.get_optimizer()

        data_config = config["data"]
        train_config = config["train"]
        eval_config = config["eval"]

        batch_type = train_config["batch_type"]
        # Token-based batch sizes are rounded to a multiple of 8 under mixed
        # precision (presumably to keep tensor shapes fp16-friendly).
        batch_size_multiple = 8 if mixed_precision and batch_type == "tokens" else 1

        def dataset_fn(input_context):
            # Builds the training dataset for one input pipeline, sharded
            # according to the distribution input context.
            return model.examples_inputter.make_training_dataset(
                data_config["train_features_file"],
                data_config.get("train_labels_file"),
                train_config["batch_size"],
                batch_type=batch_type,
                batch_size_multiple=batch_size_multiple,
                shuffle_buffer_size=train_config["sample_buffer_size"],
                length_bucket_width=train_config["length_bucket_width"],
                maximum_features_length=train_config.get("maximum_features_length"),
                maximum_labels_length=train_config.get("maximum_labels_length"),
                single_pass=train_config.get("single_pass", False),
                num_shards=input_context.num_input_pipelines,
                shard_index=input_context.input_pipeline_id,
                prefetch_buffer_size=train_config.get("prefetch_buffer_size"),
                cardinality_multiple=input_context.num_replicas_in_sync,
                weights=data_config.get("train_files_weights"),
                batch_autotune_mode=train_config.get("batch_autotune_mode"),
            )

        # Only the master replica creates the checkpoint and the evaluator.
        checkpoint = None
        evaluator = None
        if is_master:
            checkpoint = checkpoint_util.Checkpoint.from_config(
                config, model, optimizer=optimizer
            )
            # An explicit checkpoint_path is restored with weights_only=True.
            checkpoint.restore(
                checkpoint_path=checkpoint_path,
                weights_only=checkpoint_path is not None,
            )
            if with_eval:
                evaluator = evaluation.Evaluator.from_config(model, config)

        # Set gradients accumulation based on the requested effective batch size.
        effective_batch_size = train_config.get("effective_batch_size")
        if effective_batch_size is not None:
            accum_steps = _count_batch_accum(
                train_config["batch_size"],
                effective_batch_size,
                num_replicas=num_replicas,
            )
            tf.get_logger().info(
                "Accumulate gradients of %d iterations to reach effective batch size of %d",
                accum_steps,
                effective_batch_size,
            )
        else:
            accum_steps = 1

        if hvd is not None:
            trainer = training_util.HorovodTrainer(
                model, optimizer, hvd, checkpoint=checkpoint
            )
        elif num_devices > 1:
            trainer = training_util.MirroredStrategyTrainer(
                model, optimizer, checkpoint=checkpoint, devices=devices
            )
        else:
            trainer = training_util.Trainer(model, optimizer, checkpoint=checkpoint)

        summary = trainer(
            dataset_fn,
            max_step=train_config.get("max_step"),
            accum_steps=accum_steps,
            report_steps=train_config.get("save_summary_steps", 100),
            save_steps=train_config.get("save_checkpoints_steps", 5000),
            evaluator=evaluator,
            eval_steps=eval_config.get("steps", 5000),
            moving_average_decay=train_config.get("moving_average_decay"),
        )

        # An explicit null in the configuration is treated like the default (0)
        # instead of raising TypeError on the comparison below.
        average_last_checkpoints = train_config.get("average_last_checkpoints") or 0
        if checkpoint is None:
            # Non-master Horovod workers have no checkpoint to report.
            output_dir = None
        elif average_last_checkpoints > 0:
            output_dir = self.average_checkpoints(
                os.path.join(checkpoint.model_dir, "avg"),
                max_count=average_last_checkpoints,
            )
        else:
            output_dir = checkpoint.model_dir
    finally:
        if mixed_precision:
            misc.disable_mixed_precision()

    if return_summary:
        return output_dir, summary
    return output_dir