예제 #1
0
 def decorator(*args, **kwargs):
     """Call the wrapped ``fn`` with mixed precision force-enabled.

     Mixed precision is switched on with ``force=True`` before the call and
     switched off again afterwards, whether ``fn`` returns or raises.

     Args:
       *args: Positional arguments forwarded to ``fn``.
       **kwargs: Keyword arguments forwarded to ``fn``.

     Returns:
       Whatever ``fn`` returns.
     """
     misc.enable_mixed_precision(force=True)
     try:
         result = fn(*args, **kwargs)
     finally:
         # Runs on both the success and the error path.
         misc.disable_mixed_precision()
     return result
예제 #2
0
    def train(
        self,
        num_devices=1,
        with_eval=False,
        checkpoint_path=None,
        hvd=None,
        return_summary=False,
        fallback_to_cpu=True,
    ):
        """Runs the training loop.

        Args:
          num_devices: Number of devices to use for training.
          with_eval: Enable evaluation during training.
          checkpoint_path: The checkpoint path to load the model weights from.
          hvd: Optional Horovod module.
          return_summary: Return a summary of the training from this function.
          fallback_to_cpu: If no GPU is detected, allow the training to run on CPU.

        Returns:
          The path to the final model directory and, if :obj:`return_summary` is set,
          a dictionary with various training statistics. The path is ``None``
          when no checkpoint object was created, i.e. on Horovod workers whose
          rank is not 0.

        Raises:
          ValueError: if :obj:`hvd` is set and :obj:`num_devices` is greater than 1.
        """
        if hvd is None:
            num_replicas = num_devices
            is_master = True
        else:
            if num_devices > 1:
                raise ValueError(
                    "num_devices (or num_gpus) should be set to 1 when using Horovod"
                )
            num_replicas = hvd.size()
            is_master = hvd.rank() == 0

        devices = misc.get_devices(count=num_devices,
                                   fallback_to_cpu=fallback_to_cpu)

        config = self._finalize_config(training=True,
                                       num_replicas=num_replicas,
                                       num_devices=num_devices)

        # Truthy only when mixed precision was requested on this instance AND
        # enable_mixed_precision() returned a truthy value; the same flag
        # decides below whether it has to be disabled again.
        mixed_precision = self._mixed_precision and misc.enable_mixed_precision()

        # Everything that runs while mixed precision may be enabled is guarded
        # so that the setting is always reverted, even when model creation,
        # checkpoint restoration or the training loop raises.
        try:
            model = self._init_model(config)
            optimizer = model.get_optimizer()

            data_config = config["data"]
            train_config = config["train"]
            eval_config = config["eval"]

            batch_type = train_config["batch_type"]
            # NOTE(review): the multiple of 8 is presumably chosen for
            # mixed-precision hardware efficiency -- confirm.
            batch_size_multiple = (
                8 if mixed_precision and batch_type == "tokens" else 1
            )

            def dataset_fn(input_context):
                """Builds the training dataset for one input pipeline."""
                return model.examples_inputter.make_training_dataset(
                    data_config["train_features_file"],
                    data_config.get("train_labels_file"),
                    train_config["batch_size"],
                    batch_type=batch_type,
                    batch_size_multiple=batch_size_multiple,
                    shuffle_buffer_size=train_config["sample_buffer_size"],
                    length_bucket_width=train_config["length_bucket_width"],
                    maximum_features_length=train_config.get(
                        "maximum_features_length"),
                    maximum_labels_length=train_config.get(
                        "maximum_labels_length"),
                    single_pass=train_config.get("single_pass", False),
                    num_shards=input_context.num_input_pipelines,
                    shard_index=input_context.input_pipeline_id,
                    prefetch_buffer_size=train_config.get(
                        "prefetch_buffer_size"),
                    cardinality_multiple=input_context.num_replicas_in_sync,
                    weights=data_config.get("train_files_weights"),
                    batch_autotune_mode=train_config.get(
                        "batch_autotune_mode"),
                )

            # Only the master process creates a checkpoint object and an
            # evaluator; other workers pass None to the trainer.
            checkpoint = None
            evaluator = None
            if is_master:
                checkpoint = checkpoint_util.Checkpoint.from_config(
                    config, model, optimizer=optimizer)
                checkpoint.restore(
                    checkpoint_path=checkpoint_path,
                    weights_only=checkpoint_path is not None,
                )
                if with_eval:
                    evaluator = evaluation.Evaluator.from_config(model, config)

            # Set gradients accumulation based on the requested effective batch size.
            if train_config.get("effective_batch_size") is not None:
                accum_steps = _count_batch_accum(
                    train_config["batch_size"],
                    train_config["effective_batch_size"],
                    num_replicas=num_replicas,
                )
                tf.get_logger().info(
                    "Accumulate gradients of %d iterations to reach effective batch size of %d",
                    accum_steps,
                    train_config["effective_batch_size"],
                )
            else:
                accum_steps = 1

            # Pick the trainer implementation matching the distribution setup.
            if hvd is not None:
                trainer = training_util.HorovodTrainer(model,
                                                       optimizer,
                                                       hvd,
                                                       checkpoint=checkpoint)
            elif num_devices > 1:
                trainer = training_util.MirroredStrategyTrainer(
                    model, optimizer, checkpoint=checkpoint, devices=devices)
            else:
                trainer = training_util.Trainer(model,
                                                optimizer,
                                                checkpoint=checkpoint)

            summary = trainer(
                dataset_fn,
                max_step=train_config.get("max_step"),
                accum_steps=accum_steps,
                report_steps=train_config.get("save_summary_steps", 100),
                save_steps=train_config.get("save_checkpoints_steps", 5000),
                evaluator=evaluator,
                eval_steps=eval_config.get("steps", 5000),
                moving_average_decay=train_config.get("moving_average_decay"),
            )

            average_last_checkpoints = train_config.get(
                "average_last_checkpoints", 0)
            if checkpoint is None:
                output_dir = None
            elif average_last_checkpoints > 0:
                output_dir = self.average_checkpoints(
                    os.path.join(checkpoint.model_dir, "avg"),
                    max_count=average_last_checkpoints,
                )
            else:
                output_dir = checkpoint.model_dir
        finally:
            if mixed_precision:
                misc.disable_mixed_precision()

        if return_summary:
            return output_dir, summary
        return output_dir