Example #1
 def average_metrics(self, metrics: Dict[str,
                                         Any]) -> Optional[Dict[str, Any]]:
     check.true(self.hvd_config.use)
     if self.is_chief:
         self.train_process_comm_chief = cast(ipc.ZMQBroadcastServer,
                                              self.train_process_comm_chief)
         logging.debug(
             f"Chief {hvd.rank()} beginning receiving validation metrics.")
         worker_metrics, _ = self.train_process_comm_chief.gather_with_polling(
             lambda: None)
         self.train_process_comm_chief.broadcast(None)
         logging.debug(
             f"Chief {hvd.rank()} done receiving validation metrics.")
         for metric_name in metrics:
             if isinstance(metrics[metric_name], numbers.Number):
                 metrics[metric_name] /= hvd.size()
             else:
                 logging.warning(
                     f"Skipping averaging metric: {metric_name}.")
         for metric_name in metrics.keys():
             for worker_metric in worker_metrics:
                 if isinstance(worker_metric[metric_name], numbers.Number):
                     metrics[metric_name] += worker_metric[
                         metric_name] / hvd.size()
         return metrics
     else:
         self.train_process_comm_worker = cast(
             ipc.ZMQBroadcastClient, self.train_process_comm_worker)
         logging.debug(f"Worker {hvd.rank()} sending metrics.")
         self.train_process_comm_worker.send(metrics)
         # Synchronize with the chief so that there is no risk of accidentally calling send()
         # for a future gather before all workers have called send() on this gather.
         _ = self.train_process_comm_worker.recv()
         return None
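Stripped to its arithmetic, the chief-side branch above divides its own numeric metrics by the world size and then adds each gathered worker value divided by the same size, i.e. a plain mean across ranks. A minimal standalone sketch of that math, with hypothetical values in place of hvd.size() and the gathered worker metrics:

import numbers

world_size = 3  # hypothetical hvd.size(): chief + two workers
chief_metrics = {"loss": 1.0, "note": "not-a-number"}
worker_metrics = [{"loss": 2.0}, {"loss": 3.0}]  # as if gathered from the workers

averaged = dict(chief_metrics)
for name, value in chief_metrics.items():
    if isinstance(value, numbers.Number):
        averaged[name] = value / world_size + sum(
            w[name] / world_size for w in worker_metrics)
print(averaged["loss"])  # (1.0 + 2.0 + 3.0) / 3 == 2.0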
Example #2
 def average_metrics(self, metrics: Dict[str,
                                         Any]) -> Optional[Dict[str, Any]]:
     # The chief receives the metric from every worker and computes
     # the average.
     check.true(self.hvd_config.use)
     if self.is_chief:
         self.train_process_comm_chief = cast(ipc.ZMQServer,
                                              self.train_process_comm_chief)
         logging.debug(
             f"Chief {hvd.rank()} beginning receiving validation metrics.")
         worker_metrics = self.train_process_comm_chief.barrier(
             num_connections=hvd.size() - 1)
         logging.debug(
             f"Chief {hvd.rank()} done receiving validation metrics.")
         for metric_name in metrics:
             if isinstance(metrics[metric_name], numbers.Number):
                 metrics[metric_name] /= hvd.size()
             else:
                 logging.warning(
                     f"Skipping averaging metric: {metric_name}.")
         for metric_name in metrics.keys():
             for worker_metric in worker_metrics:
                 if isinstance(worker_metric[metric_name], numbers.Number):
                     metrics[metric_name] += worker_metric[
                         metric_name] / hvd.size()
         return metrics
     else:
         self.train_process_comm_worker = cast(
             ipc.ZMQClient, self.train_process_comm_worker)
         logging.debug(f"Worker {hvd.rank()} sending metrics.")
         self.train_process_comm_worker.barrier(message=metrics)
         return None
Example #3
    def wrap_dataset(self, dataset: Any, shard_dataset: bool = True) -> Any:
        """
        This should be used to wrap ``tf.data.Dataset`` objects immediately after
        they have been created. Users should use the output of this wrapper as the
        new instance of their dataset. If users create multiple datasets (e.g.,
        one for training and one for validation), users should wrap each dataset
        independently.

        Args:
            dataset: tf.data.Dataset
            shard_dataset:
                When performing multi-slot (distributed) training, this
                controls whether the dataset is sharded so that each training process
                (one per slot) sees unique data. If set to False, users must manually
                configure each process to use unique data.
        """
        if not self.env.managed_training:
            return dataset

        self.dataset_initialized = True
        if not self.hvd_config.use or not isinstance(dataset, tf.data.Dataset) or not shard_dataset:
            if self.hvd_config and not shard_dataset:
                logging.info("Dataset sharding skipped.")
            return dataset

        hvd.require_horovod_type("tensorflow.keras", "TFKerasContext.wrap_dataset was called.")
        dataset = dataset.shard(hvd.size(), hvd.rank())
        logging.debug(f"Sharded dataset to index {hvd.rank()} of {hvd.size()}.")
        return dataset
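The wrapper's sharding reduces to a single tf.data.Dataset.shard call keyed on the Horovod size and rank. A short sketch of that call in isolation, using made-up values in place of hvd.size() and hvd.rank():

import tensorflow as tf

num_workers, worker_rank = 4, 1  # hypothetical hvd.size() / hvd.rank()

dataset = tf.data.Dataset.range(16)
sharded = dataset.shard(num_shards=num_workers, index=worker_rank)
print(list(sharded.as_numpy_iterator()))  # [1, 5, 9, 13]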
Example #4
    def wrap_dataset(self, dataset: Any, shard_dataset: bool = True) -> Any:
        """
        This should be used to wrap ``tf.data.Dataset`` objects immediately after
        they have been created. Users should use the output of this wrapper as the
        new instance of their dataset. If users create multiple datasets (e.g., one
        for training and one for testing), users should wrap each dataset
        independently. For example, if users instantiate their training dataset within
        ``build_train_spec()``, they should call ``dataset = wrap_dataset(dataset)``
        prior to passing it into ``tf.estimator.TrainSpec``.

        Args:
            dataset: tf.data.Dataset
            shard_dataset:
                When performing multi-slot (distributed) training, this
                controls whether the dataset is sharded so that each training process
                (one per slot) sees unique data. If set to False, users must manually
                configure each process to use unique data.

        """
        if not self.env.training:
            return dataset

        hvd.require_horovod_type("tensorflow",
                                 "EstimatorContext.wrap_dataset was called.")

        self.dataset_initialized = True
        if not self.hvd_config.use or self.input_from_dataflow or not shard_dataset:
            if self.hvd_config and not shard_dataset:
                logging.info("Dataset sharding skipped.")
            return dataset

        dataset = dataset.shard(hvd.size(), hvd.rank())
        logging.debug(
            f"Sharded dataset to index {hvd.rank()} of {hvd.size()}.")
        return dataset
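Per the docstring above, the wrapper should be applied inside the user's input function before the dataset reaches tf.estimator.TrainSpec. A hedged sketch of that usage; MyTrial and make_dataset are invented names standing in for the surrounding Trial code, not part of this example:

import tensorflow as tf

def make_dataset() -> tf.data.Dataset:
    # Stand-in for the user's real input pipeline.
    return tf.data.Dataset.range(100).batch(10)

class MyTrial:  # hypothetical user Trial; `context` is assumed to be the EstimatorContext
    def __init__(self, context):
        self.context = context

    def build_train_spec(self) -> tf.estimator.TrainSpec:
        def input_fn():
            # Wrap (and thereby shard) the dataset before the estimator consumes it.
            return self.context.wrap_dataset(make_dataset())

        return tf.estimator.TrainSpec(input_fn)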
Example #5
    def _combine_metrics_across_processes(
        self, metrics: Dict[str, Any], num_batches: int
    ) -> Tuple[Optional[Dict[str, Any]], Optional[List[int]]]:
        # The chief receives the metric from every other training process.
        check.true(self.hvd_config.use)

        metrics_lists = {}  # type: Dict[str, Any]
        batches_per_process = []  # type: List[int]
        if self.is_chief:
            self.train_process_comm_chief = cast(ipc.ZMQServer,
                                                 self.train_process_comm_chief)
            worker_metrics = self.train_process_comm_chief.barrier(
                num_connections=hvd.size() - 1)
            worker_metrics = cast(List[ipc.MetricsInfo], worker_metrics)

            for metric_name in metrics.keys():
                metrics_lists[metric_name] = [metrics[metric_name]]
                for worker_metric in worker_metrics:
                    metrics_lists[metric_name].append(
                        worker_metric.metrics[metric_name])

            batches_per_process.append(num_batches)
            for worker_metric in worker_metrics:
                batches_per_process.append(worker_metric.num_batches)

            return metrics_lists, batches_per_process
        else:
            self.train_process_comm_worker = cast(
                ipc.ZMQClient, self.train_process_comm_worker)
            self.train_process_comm_worker.barrier(message=ipc.MetricsInfo(
                metrics=metrics, num_batches=num_batches))
            return None, None
Example #6
    def average_metrics(self, metrics: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        check.true(self.hvd_config.use)
        all_metrics = self.context.distributed._zmq_gather(metrics)
        if not self.is_chief:
            return None
        assert all_metrics is not None, "chief did not get metrics from _zmq_gather()"

        for key in metrics:
            if isinstance(metrics[key], numbers.Number):
                metrics[key] = sum(m[key] for m in all_metrics) / hvd.size()
            else:
                logging.warning(f"Skipping averaging metric: {key}.")
        return metrics
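Without the distributed plumbing, the reduction performed on the chief is an element-wise mean over the gathered dictionaries. A standalone sketch with a hypothetical gathered list in place of self.context.distributed._zmq_gather(metrics):

all_metrics = [{"loss": 1.0, "acc": 0.5}, {"loss": 0.5, "acc": 0.75}]  # one dict per rank

averaged = {
    key: sum(m[key] for m in all_metrics) / len(all_metrics)
    for key in all_metrics[0]
}
print(averaged)  # {'loss': 0.75, 'acc': 0.625}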
Example #7
    def average_metrics(self, metrics: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        assert (
            self.context.distributed.size > 1
        ), "average_metrics can only be called during distributed training"
        all_metrics = self.context.distributed.gather(metrics)
        if not self.is_chief:
            return None
        assert all_metrics is not None, "chief did not get metrics from gather()"

        for key in metrics:
            if isinstance(metrics[key], numbers.Number):
                metrics[key] = sum(m[key] for m in all_metrics) / hvd.size()
            else:
                logging.warning(f"Skipping averaging metric: {key}.")
        return metrics
Example #8
    def _average_training_metrics(
            self, per_batch_metrics: List[Dict[str,
                                               Any]]) -> List[Dict[str, Any]]:
        """Average training metrics across GPUs"""
        check.true(self.hvd_config.use,
                   "Can only average training metrics in multi-GPU training.")
        metrics_timeseries = util._list_to_dict(per_batch_metrics)

        # combined_timeseries is: dict[metric_name] -> 2d-array.
        # A measurement is accessed via combined_timeseries[metric_name][process_idx][batch_idx].
        combined_timeseries, _ = self._combine_metrics_across_processes(
            metrics_timeseries, num_batches=len(per_batch_metrics))

        # If the value for a metric is a single-element array, the averaging process will
        # change that into just the element. We record what metrics are single-element arrays
        # so we can wrap them in an array later (for perfect compatibility with non-averaging
        # codepath).
        array_metrics = []
        for metric_name in per_batch_metrics[0].keys():
            if isinstance(per_batch_metrics[0][metric_name], np.ndarray):
                array_metrics.append(metric_name)

        if self.is_chief:
            combined_timeseries_type = Dict[str, List[List[Any]]]
            combined_timeseries = cast(combined_timeseries_type,
                                       combined_timeseries)
            num_batches = len(per_batch_metrics)
            num_processes = hvd.size()
            averaged_metrics_timeseries = {}  # type: Dict[str, List]

            for metric_name in combined_timeseries.keys():
                averaged_metrics_timeseries[metric_name] = []
                for batch_idx in range(num_batches):
                    batch = [
                        combined_timeseries[metric_name][process_idx][batch_idx]
                        for process_idx in range(num_processes)
                    ]

                    np_batch = np.array(batch)
                    batch_avg = np.mean(
                        np_batch[np_batch != None])  # noqa: E711
                    if metric_name in array_metrics:
                        batch_avg = np.array(batch_avg)
                    averaged_metrics_timeseries[metric_name].append(batch_avg)
            per_batch_metrics = util._dict_to_list(averaged_metrics_timeseries)
        return per_batch_metrics
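To make the per-batch reduction above concrete, here is a toy version of the inner loop: one metric recorded by two processes over three batches, averaged batch by batch. The values are invented, and the != None filter mirrors the code above (it simply keeps every entry here):

import numpy as np

combined_timeseries = {"loss": [[1.0, 2.0, 6.0],   # process 0, batches 0..2
                                [3.0, 4.0, 4.0]]}  # process 1, batches 0..2
num_processes, num_batches = 2, 3

averaged = {name: [] for name in combined_timeseries}
for name, per_process in combined_timeseries.items():
    for batch_idx in range(num_batches):
        batch = np.array([per_process[p][batch_idx] for p in range(num_processes)])
        averaged[name].append(float(np.mean(batch[batch != None])))  # noqa: E711
print(averaged["loss"])  # [2.0, 3.0, 5.0]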
Example #9
    def _set_data_loaders(self) -> None:
        skip_batches = (self.env.first_step() - 1) * self.batches_per_step

        nreplicas = hvd.size() if self.hvd_config.use else 1
        rank = hvd.rank() if self.hvd_config.use else 0

        self.training_loader = self.trial.build_training_data_loader(
        ).get_data_loader(repeat=True,
                          skip=skip_batches,
                          num_replicas=nreplicas,
                          rank=rank)

        validation_dataset = self.trial.build_validation_data_loader()
        if self._evaluate_batch_defined():
            self.validation_loader = validation_dataset.get_data_loader(
                repeat=False, skip=0, num_replicas=nreplicas, rank=rank)
        elif self.is_chief:
            self.validation_loader = validation_dataset.get_data_loader(
                repeat=False, skip=0, num_replicas=1, rank=0)
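The get_data_loader() call itself is not shown in this example, so as a rough analogy only (an assumption, not the actual implementation), here is what num_replicas/rank sharding looks like in plain PyTorch, with hypothetical values for the replica count and rank:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

nreplicas, rank = 4, 1  # hypothetical hvd.size() / hvd.rank()

dataset = TensorDataset(torch.arange(16))
# Passing num_replicas/rank explicitly avoids needing an initialized process group.
sampler = DistributedSampler(dataset, num_replicas=nreplicas, rank=rank, shuffle=False)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
print([batch[0].tolist() for batch in loader])  # [[1, 5], [9, 13]]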
Example #10
    def _set_data_loaders(self) -> None:
        skip_batches = self.env.initial_workload.total_batches_processed

        nreplicas = hvd.size() if self.hvd_config.use else 1
        rank = hvd.rank() if self.hvd_config.use else 0

        self.training_loader = self.trial.build_training_data_loader(
        ).get_data_loader(repeat=True,
                          skip=skip_batches,
                          num_replicas=nreplicas,
                          rank=rank)
        self.context._epoch_len = len(self.training_loader)

        validation_dataset = self.trial.build_validation_data_loader()
        if self._evaluate_batch_defined():
            self.validation_loader = validation_dataset.get_data_loader(
                repeat=False, skip=0, num_replicas=nreplicas, rank=rank)
        elif self.is_chief:
            self.validation_loader = validation_dataset.get_data_loader(
                repeat=False, skip=0, num_replicas=1, rank=0)
Example #11
    def wrap_dataset(self, dataset: Any) -> Any:
        """
        This should be used to wrap ``tf.data.Dataset`` objects immediately after
        they have been created. Users should use the output of this wrapper as the
        new instance of their dataset. If users create multiple datasets (e.g., one
        for training and one for testing), users should wrap each dataset
        independently. For example, if users instantiate their training dataset within
        ``build_train_spec()``, they should call ``dataset = wrap_dataset(dataset)``
        prior to passing it into ``tf.estimator.TrainSpec``.
        """
        hvd.require_horovod_type("tensorflow",
                                 "EstimatorContext.wrap_dataset was called.")

        self.dataset_initialized = True
        if not self.hvd_config.use or self.input_from_dataflow:
            return dataset
        dataset = dataset.shard(hvd.size(), hvd.rank())
        logging.debug(
            f"Sharded dataset to index {hvd.rank()} of {hvd.size()}.")
        return dataset
Example #12
    def _set_data_loaders(self) -> None:
        skip_batches = self.env.initial_workload.total_batches_processed

        nreplicas = hvd.size() if self.hvd_config.use else 1
        rank = hvd.rank() if self.hvd_config.use else 0

        # TODO: the number of ways a user could get this wrong is alarming.  Right now we don't
        # have any validation, but we should add some.  Maybe deprecate the old way?  Or mark the
        # new way as "advanced"?
        train_data = self.trial.build_training_data_loader()
        if isinstance(train_data, pytorch.DataLoader):
            # Old-API, a user-provided det.pytorch.DataLoader.
            self.training_loader = train_data.get_data_loader(
                repeat=True,
                skip=skip_batches,
                num_replicas=nreplicas,
                rank=rank)
        else:
            # New-API, assume the user called context.make_training_batch_sampler.
            self.training_loader = train_data

        self.context._epoch_len = len(self.training_loader)

        validation_data = self.trial.build_validation_data_loader()
        if self._evaluate_batch_defined():
            if isinstance(validation_data, pytorch.DataLoader):
                # Old-API, a user-provided det.pytorch.DataLoader.
                self.validation_loader = validation_data.get_data_loader(
                    repeat=False, skip=0, num_replicas=nreplicas, rank=rank)
            else:
                # New-API, assume the user called context.make_validation_batch_sampler.
                self.validation_loader = validation_data
        elif self.is_chief:
            if isinstance(validation_data, pytorch.DataLoader):
                # Old-API, a user-provided det.pytorch.DataLoader.
                self.validation_loader = validation_data.get_data_loader(
                    repeat=False, skip=0, num_replicas=1, rank=0)
            else:
                # Warning: if the user called make_validation_batch_sampler, this branch will not
                # behave as intended, since the chief is expected to see the full validation set here.
                self.validation_loader = validation_data
Example #13
    def wrap_dataset(self, dataset: Any) -> Any:
        """
        This should be used to wrap ``tf.data.Dataset`` objects immediately after
        they have been created. Users should use the output of this wrapper as the
        new instance of their dataset. If users create multiple datasets (e.g.,
        one for training and one for testing), users should wrap each dataset
        independently.

        Args:
            dataset: tf.data.Dataset
        """
        self.dataset_initialized = True
        if not self.hvd_config.use or not isinstance(dataset, tf.data.Dataset):
            return dataset

        hvd.require_horovod_type("tensorflow.keras",
                                 "TFKerasContext.wrap_dataset was called.")
        dataset = dataset.shard(hvd.size(), hvd.rank())
        logging.debug(
            f"Sharded dataset to index {hvd.rank()} of {hvd.size()}.")
        return dataset
Example #14
    def _train_for_step(self, step_id: int, num_batches: int,
                        total_batches_processed: int) -> workload.Response:
        check.gt(step_id, 0)
        self.context.reset_reducers()

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        for model in self.context.models:
            model.train()

        start = total_batches_processed
        end = start + num_batches

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch_start_time = time.time()
            self.prof.update_batch_idx(batch_idx)
            with self.prof.record_timing("dataloader_next"):
                batch = next(self.training_iterator)
            batch_inputs = self.trial.get_batch_length(batch)
            num_inputs += batch_inputs

            with self.prof.record_timing("to_device"):
                batch = self.context.to_device(batch)

            self.context._current_batch_idx = batch_idx
            if self.context.is_epoch_start():
                for callback in self.callbacks.values():
                    with self.prof.record_timing(
                            f"callbacks.{callback.__class__.__name__}.on_training_epoch_start"
                    ):
                        callback.on_training_epoch_start()
            self.context._loss_ids = {}

            with self.prof.record_timing("train_batch"):
                if self.context.profiler:
                    with self.context.profiler as torch_profiler:
                        tr_metrics = self.trial.train_batch(
                            batch=batch,
                            epoch_idx=self.get_epoch_idx(batch_idx),
                            batch_idx=batch_idx,
                        )
                        torch_profiler.step()
                else:
                    tr_metrics = self.trial.train_batch(
                        batch=batch,
                        epoch_idx=self.get_epoch_idx(batch_idx),
                        batch_idx=batch_idx,
                    )
            if self._should_update_scaler():
                self.context._scaler.update()
            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}
            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )

            # Step learning rate of a pytorch.LRScheduler.
            with self.prof.record_timing("step_lr_schedulers"):
                for lr_scheduler in self.context.lr_schedulers:
                    self._auto_step_lr_scheduler_per_batch(
                        batch_idx, lr_scheduler)

            with self.prof.record_timing("from_device"):
                for name, metric in tr_metrics.items():
                    # Convert PyTorch metric values to NumPy, so that
                    # `det.util.encode_json` handles them properly without
                    # needing a dependency on PyTorch.
                    if isinstance(metric, torch.Tensor):
                        metric = metric.cpu().detach().numpy()
                    tr_metrics[name] = metric

            batch_dur = time.time() - batch_start_time
            samples_per_second = batch_inputs / batch_dur
            self.prof.record_metric("samples_per_second", samples_per_second)
            per_batch_metrics.append(tr_metrics)

        # Aggregate and reduce training metrics from all the training processes.
        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            with self.prof.record_timing("average_training_metrics"):
                per_batch_metrics = self._average_training_metrics(
                    per_batch_metrics)
        if self.hvd_config.use:
            num_inputs *= hvd.size()
        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch
        # metrics are even logical for a custom reducer.
        with self.prof.record_timing("reduce_metrics"):
            metrics["avg_metrics"].update(
                self._convert_metrics_to_numpy(
                    self.context.reduce_metrics(for_training=True)))

        if not self.is_chief:
            # The training metrics are reported only in the chief process.
            return workload.Skipped()

        logging.debug(
            f"Done training step: {num_inputs} records in {num_batches} batches."
        )

        return metrics
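A tiny standalone illustration of the from_device conversion loop above: tensor-valued metrics are detached and converted to NumPy so that later JSON encoding does not require PyTorch, while plain Python numbers pass through unchanged.

import torch

tr_metrics = {"loss": torch.tensor(0.25, requires_grad=True), "lr": 0.001}
for name, metric in tr_metrics.items():
    if isinstance(metric, torch.Tensor):
        metric = metric.cpu().detach().numpy()
    tr_metrics[name] = metric
print(tr_metrics)  # {'loss': array(0.25, dtype=float32), 'lr': 0.001}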
Example #15
    def _train_for_step(self, step_id: int,
                        batches_per_step: int) -> workload.Response:
        check.gt(step_id, 0)

        step_idx = step_id - 1
        start = step_idx * batches_per_step
        end = start + batches_per_step

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        self.model.train()

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch = next(self.training_iterator)
            num_inputs += data_length(batch)

            batch = self._to_device(batch)
            # Forward pass.
            tr_metrics = self.trial.train_batch(
                batch=batch,
                model=self.model,
                epoch_idx=self.get_epoch_idx(batch_idx),
                batch_idx=batch_idx,
            )

            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}

            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                "mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )
            check.is_in("loss", tr_metrics.keys(),
                        'Please include "loss" in your training metrics.')

            # Backwards pass.
            loss = tr_metrics["loss"]
            communicate_and_update = (
                batch_idx + 1) % self.hvd_config.aggregation_frequency == 0
            if self.use_amp():
                with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                    if self.hvd_config.use and communicate_and_update:
                        self.optimizer.synchronize()
            else:
                loss.backward()

            if communicate_and_update:
                parameters = (self.model.parameters() if not self.use_amp()
                              else apex.amp.master_params(self.optimizer))

                if self.hvd_config.average_aggregated_gradients:
                    self._average_gradients(
                        parameters=parameters,
                        divisor=self.hvd_config.aggregation_frequency)

                self._clip_grads(parameters)

                if self.hvd_config.use and self.use_amp():
                    with self.optimizer.skip_synchronize():
                        self.optimizer.step()
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_helper.should_step_lr(
                        batches_completed=batch_idx + 1,
                        epoch_length=len(self.training_loader),
                        aggregation_frequency=self.hvd_config.aggregation_frequency,
                ):
                    self.lr_helper.step()

            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

            check.is_in("loss", tr_metrics,
                        'Please include "loss" in your training metrics.')
            per_batch_metrics.append(tr_metrics)

        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            per_batch_metrics = self._average_training_metrics(
                per_batch_metrics)

        if not self.is_chief:
            return workload.Skipped()

        if self.hvd_config.use:
            num_inputs *= hvd.size()

        logging.debug(
            f"Done training step: {num_inputs} records in {batches_per_step} batches."
        )
        return det.util.make_metrics(num_inputs, per_batch_metrics)
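The communicate_and_update flag above gates gradient exchange and the optimizer step to every Nth batch. A trivial sketch of that gating with a hypothetical aggregation frequency of 4:

aggregation_frequency = 4  # hypothetical self.hvd_config.aggregation_frequency

for batch_idx in range(8):
    communicate_and_update = (batch_idx + 1) % aggregation_frequency == 0
    print(batch_idx, communicate_and_update)  # True only for batch_idx 3 and 7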
Example #16
    def _train_for_step(self, step_id: int, num_batches: int,
                        total_batches_processed: int) -> workload.Response:
        check.gt(step_id, 0)

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        for model in self.context.models:
            model.train()

        start = total_batches_processed
        end = start + num_batches

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch = next(self.training_iterator)
            num_inputs += data_length(batch)
            batch = self.context._to_device(batch)

            self.context._current_batch_idx = batch_idx
            self.context._loss_ids = {}
            tr_metrics = self.trial.train_batch(
                batch=batch,
                model=self.context.models[0],
                epoch_idx=self.get_epoch_idx(batch_idx),
                batch_idx=batch_idx,
            )
            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}
            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )
            check.is_in("loss", tr_metrics.keys(),
                        'Please include "loss" in your training metrics.')

            # Step learning rate of a LRScheduler.
            for lr_scheduler in self.context.lr_schedulers:
                self._auto_step_lr_scheduler_per_batch(batch_idx, lr_scheduler)

            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

            check.is_in("loss", tr_metrics,
                        'Please include "loss" in your training metrics.')
            per_batch_metrics.append(tr_metrics)

        # Aggregate and reduce training metrics from all the training processes.
        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            per_batch_metrics = self._average_training_metrics(
                per_batch_metrics)
        if self.hvd_config.use:
            num_inputs *= hvd.size()
        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        if not self.is_chief:
            # The training metrics are reported only in the chief process.
            return workload.Skipped()

        logging.debug(
            f"Done training step: {num_inputs} records in {num_batches} batches."
        )

        return metrics
Example #17
    def _train_for_step(self, step_id: int, batches_per_step: int) -> workload.Response:
        check.gt(step_id, 0)

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        self.context.model.train()

        for callback in self.callbacks.values():
            callback.on_train_step_start(step_id)

        step_idx = step_id - 1
        start = step_idx * batches_per_step
        end = start + batches_per_step

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch = next(self.training_iterator)
            num_inputs += data_length(batch)

            batch = self._to_device(batch)
            # Forward pass.
            tr_metrics = self.trial.train_batch(
                batch=batch,
                model=self.context.model,
                epoch_idx=self.get_epoch_idx(batch_idx),
                batch_idx=batch_idx,
            )

            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}

            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                "mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )
            check.is_in("loss", tr_metrics.keys(), 'Please include "loss" in you training metrics.')

            # Backwards pass.
            loss = tr_metrics["loss"]
            communicate_and_update = (batch_idx + 1) % self.hvd_config.aggregation_frequency == 0
            if self.use_amp():
                with apex.amp.scale_loss(loss, self.context.optimizer) as scaled_loss:
                    scaled_loss.backward()
                    if self.hvd_config.use and communicate_and_update:
                        # When using horovod, we need to finish communicating gradient
                        # updates before they are unscaled which happens when we exit
                        # of this context manager.
                        self.context.optimizer.synchronize()
            else:
                loss.backward()

                # Communication needs to be synchronized so that is completed
                # before we apply gradient clipping and `step()`.
                if communicate_and_update and self.hvd_config.use:
                    self.context.optimizer.synchronize()

            if communicate_and_update:
                parameters = (
                    self.context.model.parameters()
                    if not self.use_amp()
                    else apex.amp.master_params(self.context.optimizer)
                )

                if self.hvd_config.average_aggregated_gradients:
                    self._average_gradients(
                        parameters=parameters, divisor=self.hvd_config.aggregation_frequency
                    )

                # TODO: Remove this check in v0.12.8.
                check.false(
                    self.env.hparams.get("clip_grad_l2_norm", None)
                    or self.env.hparams.get("clip_grad_val", None),
                    "Please specify gradient clipping via callbacks.",
                )

                for callback in self.callbacks.values():
                    callback.on_before_optimizer_step(parameters)

                if self.hvd_config.use:
                    with self.context.optimizer.skip_synchronize():
                        self.context.optimizer.step()
                else:
                    self.context.optimizer.step()
                self.context.optimizer.zero_grad()

                # Step learning rate of a LRScheduler.
                if self.context.lr_scheduler is not None:
                    self._auto_step_lr_scheduler_per_batch(batch_idx, self.context.lr_scheduler)

            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

            check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.')
            per_batch_metrics.append(tr_metrics)

        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            per_batch_metrics = self._average_training_metrics(per_batch_metrics)

        if self.hvd_config.use:
            num_inputs *= hvd.size()

        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        for callback in self.callbacks.values():
            callback.on_train_step_end(step_id, metrics)

        if not self.is_chief:
            return workload.Skipped()

        logging.debug(f"Done training step: {num_inputs} records in {batches_per_step} batches.")

        return metrics
Example #18
    def _train_for_step(self, step_id: int, num_batches: int,
                        total_batches_processed: int) -> workload.Response:
        check.gt(step_id, 0)
        self.context.experimental.reset_reducers()

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        for model in self.context.models:
            model.train()

        start = total_batches_processed
        end = start + num_batches

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch = next(self.training_iterator)

            ## old code:
            # num_inputs += pytorch.data_length(batch)
            # batch = self.context.to_device(batch)

            num_inputs += self.trial._records_in_batch(batch)
            batch = self.trial._batch_to_device(batch, self.context)

            self.context._current_batch_idx = batch_idx
            self.context._loss_ids = {}
            tr_metrics = self.trial.train_batch(
                batch=batch,
                epoch_idx=self.get_epoch_idx(batch_idx),
                batch_idx=batch_idx,
            )
            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}
            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )

            # Step learning rate of a pytorch.LRScheduler.
            for lr_scheduler in self.context.lr_schedulers:
                self._auto_step_lr_scheduler_per_batch(batch_idx, lr_scheduler)

            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

            per_batch_metrics.append(tr_metrics)

        # Aggregate and reduce training metrics from all the training processes.
        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            per_batch_metrics = self._average_training_metrics(
                per_batch_metrics)
        if self.hvd_config.use:
            num_inputs *= hvd.size()
        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch
        # metrics are even logical for a custom reducer.
        metrics["avg_metrics"].update(
            self._convert_metrics_to_numpy(
                self.context.experimental.reduce_metrics(for_training=True)))

        if not self.is_chief:
            # The training metrics are reported only in the chief process.
            return workload.Skipped()

        logging.debug(
            f"Done training step: {num_inputs} records in {num_batches} batches."
        )

        return metrics
Example #19
    def _compute_validation_metrics(self) -> workload.Response:
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        self.model.eval()
        num_inputs = 0
        metrics = {}  # type: Optional[Dict[str, Any]]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for batch in self.validation_loader:
                batch = self._to_device(batch)
                num_inputs += data_length(batch)

                vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                        model=self.model)
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "validation_metrics() must return a "
                    "dictionary of string names to Tensor "
                    "metrics",
                )
                # TODO: For performance perform -> cpu() only at the end of validation.
                batch_metrics.append(
                    self._convert_metrics_to_numpy(vld_metrics))

            keys = cast(Any, keys)
            metrics = self._reduce_metrics(
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=self._prepare_metrics_reducers(keys=keys),
            )

            if self.hvd_config.use:
                num_inputs *= hvd.size()

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader, model=self.model)

                check.is_instance(
                    metrics, dict,
                    f"eval() must return a dictionary, got {type(metrics)}.")

                metrics = self._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example #20
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.experimental.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        for callback in self.callbacks.values():
            logging.warning(
                "on_validation_step_start is now deprecated, please use on_validation_start instead"
            )
            callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for batch in self.validation_loader:
                batch = self.context.to_device(batch)
                num_inputs += pytorch.data_length(batch)

                vld_metrics = self.trial.evaluate_batch(batch=batch)
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "validation_metrics() must return a "
                    "dictionary of string names to Tensor "
                    "metrics",
                )
                # TODO: For performance perform -> cpu() only at the end of validation.
                batch_metrics.append(
                    self._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            metrics = self._reduce_metrics(
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=self._prepare_metrics_reducers(keys=keys),
            )

            if self.hvd_config.use:
                num_inputs *= hvd.size()

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"eval() must return a dictionary, got {type(metrics)}.")

                metrics = self._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            self._convert_metrics_to_numpy(
                self.context.experimental.reduce_metrics(for_training=False)))

        if self.hvd_config.use and any(
                util.is_overridden(c.on_validation_end, pytorch.PyTorchCallback)
                or util.is_overridden(c.on_validation_step_end,
                                      pytorch.PyTorchCallback)
                for c in self.callbacks.values()):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            logging.warning(
                "on_validation_step_end is now deprecated, please use on_validation_end instead"
            )
            callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example #21
    def _init_shard(self) -> None:
        if not self._hvd_config.use:
            return

        self._shard_rank = hvd.rank()
        self._num_shards = hvd.size()
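This final example only records the shard rank and count; how they are applied is not shown here. As an assumption about typical usage, a rank/size pair shards a sequence of records by striding over it:

shard_rank, num_shards = 1, 4  # hypothetical hvd.rank() / hvd.size()

records = list(range(10))
my_records = records[shard_rank::num_shards]
print(my_records)  # [1, 5, 9]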