def as_batches(
    self,
    batches: Optional[int] = None,
    records: Optional[int] = None,
    epochs: Optional[int] = None,
) -> int:
    if sum((batches is not None, records is not None, epochs is not None)) != 1:
        raise ValueError(
            "exactly one of batches, records, or epochs must be set: "
            f"batches={batches} records={records} epochs={epochs}"
        )
    if batches is not None:
        return batches
    if records is not None:
        check.gt(self.global_batch_size, 0, "global_batch_size must be positive")
        return max(records // self.global_batch_size, 1)
    if epochs is not None:
        check.is_instance(
            self.records_per_epoch, int, "records_per_epoch must be set to an integer"
        )
        assert self.records_per_epoch is not None
        check.gt(self.global_batch_size, 0, "global_batch_size must be positive")
        return max((epochs * self.records_per_epoch) // self.global_batch_size, 1)
    # Make mypy happy.
    raise ValueError("invalid length")
Example 2
    def _compute_validation_metrics(self) -> workload.Response:
        metrics = self._launch_evaluate()
        num_inputs = self.multiplexer.get_test_inputs()

        if self.hvd_config.use:
            # Use a global ZMQ barrier here because we have observed cases where hvd.allreduce
            # may hang when called minutes apart by different workers, which may happen if
            # workers complete evaluation at different speeds.
            self._global_barrier()

            num_inputs = hvd.allreduce(num_inputs,
                                       average=False,
                                       name="validation_num_inputs")
            if isinstance(num_inputs, EagerTensor):
                # Horovod will promote an int to a tensor in eager mode.
                num_inputs = num_inputs.numpy()

        metrics = self._allreduce_logs(metrics)
        check.gt(len(metrics), 0)

        self.multiplexer._test_end(metrics)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example 3
    def _compute_validation_metrics(self) -> workload.Response:
        validation_start_time = time.time()
        metrics = self._launch_evaluate()
        num_inputs, num_batches = self.multiplexer.get_test_inputs()

        if self.context.distributed.size > 1:
            # Use a global ZMQ barrier here because we have observed cases where hvd.allreduce
            # may hang when called minutes apart by different workers, which may happen if
            # workers complete evaluation at different speeds.
            _ = self.context.distributed.gather(None)

            num_inputs = hvd.allreduce(num_inputs, average=False, name="validation_num_inputs")
            if isinstance(num_inputs, EagerTensor):
                # Horovod will promote an int to a tensor in eager mode.
                num_inputs = num_inputs.numpy()
            num_batches = hvd.allreduce(num_batches, average=False, name="validation_num_batches")
            if isinstance(num_batches, EagerTensor):
                num_batches = num_batches.numpy()

        metrics = self._allreduce_logs(metrics)
        check.gt(len(metrics), 0)

        self.multiplexer._test_end(metrics)

        if not self.is_chief:
            return {}

        step_duration = time.time() - validation_start_time
        logging.info(det.util.make_timing_log("validated", step_duration, num_inputs, num_batches))

        self.metric_writer.on_validation_step_end(self.steps_completed, metrics)
        self.upload_tb_files()
        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example 4
    def __init__(
        self,
        sequence: tf.keras.utils.Sequence,
        sampler: _Sampler,
        repeat: bool,
        workers: int,
        max_queue_size: int,
    ):
        self.sequence = sequence
        self.sampler = sampler
        self.repeat = repeat
        self.max_queue_size = max_queue_size
        check.gt(max_queue_size, 0, "max_queue_size must be greater than zero")

        # Coordination logic.
        self.order = 0
        self.requested = collections.deque()  # type: Deque[int]
        self.received = {}  # type: Dict[int, Any]
        self.started = False
        self.stopped = False
        self.index_iter = None  # type: Optional[Iterator]

        # Interthread/interprocess communications.
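        # queue_class() presumably returns a thread- or process-safe queue type to
        # match worker_class(); calling it twice fetches the type, then builds an instance.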
        self.queries = self.queue_class()()
        self.answers = self.queue_class()()

        self.workers = [
            self.worker_class()(
                target=_worker, args=(self.sequence, self.queries, self.answers)
            )
            for _ in range(workers)
        ]
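_worker itself is not part of this snippet; a minimal sketch of what such a worker loop might look like, assuming the queries queue carries Sequence indices (with None as a shutdown signal) and the answers queue carries (index, batch) pairs:

def _worker(sequence, queries, answers):
    # Hypothetical worker loop: pull an index off the queries queue, look up the
    # corresponding batch in the Keras Sequence, and push it onto the answers queue.
    while True:
        index = queries.get()
        if index is None:
            # Shutdown signal from the parent.
            break
        answers.put((index, sequence[index]))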
Example 5
    def __init__(self, batch_sampler: torch.utils.data.BatchSampler,
                 num_workers: int, rank: int) -> None:
        check.gt(rank, -1, "rank must be non-negative")
        check.gt(num_workers, 0, "num_workers must be positive")
        check.lt(rank, num_workers, "rank must be less than num_workers")

        self.batch_sampler = batch_sampler
        self.num_workers = num_workers
        self.rank = rank
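Only the constructor appears in the snippet; a minimal sketch of how such a sampler might shard batches round-robin across workers (an assumption, not code from the source):

    def __iter__(self):
        # Hypothetical round-robin sharding: worker `rank` yields every
        # num_workers-th batch, starting at its own rank.
        for i, batch in enumerate(self.batch_sampler):
            if i % self.num_workers == self.rank:
                yield batch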
Example 6
    def _send_recv_workload(self, wkld: workload.Workload,
                            args: List[Any]) -> workload.Response:
        # Broadcast every workload to every worker on this machine.
        self.broadcast_server.broadcast((wkld, args))

        if wkld.kind == workload.Workload.Kind.TERMINATE:
            # Do not perform health checks once workers have been instructed to terminate.
            self._worker_process_ids = []

        try:
            responses, exception_received = self.broadcast_server.gather_with_polling(
                self._health_check)
        except det.errors.WorkerError:
            if wkld.kind == workload.Workload.Kind.TERMINATE:
                return {}
            raise

        if exception_received:
            raise det.errors.WorkerError("Training process died.")

        # Find the response from the chief worker for the trial (the only non-SkippedWorkload). The
        # chief may report to another container, in which case we will only have SkippedWorkloads.
        chief_worker_response = None  # type: Optional[workload.Metrics]
        for response in responses:
            if isinstance(response, workload.Skipped):
                continue
            # Any other response must be a Dict[str, Any]-like object.
            check.is_instance(
                response, dict,
                f"Received non-metrics object from worker: {response}")
            # There should only be one chief response.
            # Special case InvalidHP messages
            if chief_worker_response != {
                    "metrics": {},
                    "stop_requested": False,
                    "invalid_hp": True,
                    "init_invalid_hp": False,
            }:
                check.is_none(
                    chief_worker_response,
                    "Received multiple non-SkippedWorkload messages.")
            chief_worker_response = cast(Dict[str, Any], response)

        # Confirm that if we did not see a chief response, then we are not the chief machine.
        if chief_worker_response is None:
            check.gt(
                self.rendezvous_info.get_rank(),
                0,
                "Received SkippedWorkload message from chief worker.",
            )

        if chief_worker_response is None:
            return workload.Skipped()
        return chief_worker_response
Example 7
def _init_device(self) -> None:
    self.n_gpus = len(self.env.container_gpus)
    if self.hvd_config.use:
        check.gt(self.n_gpus, 0)
        # We launch a horovod process per GPU. Each process
        # needs to bind to a unique GPU.
        self.device = torch.device(hvd.local_rank())
        torch.cuda.set_device(self.device)
    elif self.n_gpus > 0:
        self.device = torch.device("cuda", 0)
    else:
        self.device = torch.device("cpu")
    check.is_not_none(self.device)
Example 8
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for callback in self.callbacks.values():
                callback.on_validation_epoch_start()
            for idx, batch in enumerate(self.validation_loader):
                batch = self.context.to_device(batch)
                num_inputs += self.trial.get_batch_length(batch)

                if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                    vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                            batch_idx=idx)
                else:
                    vld_metrics = self.trial.evaluate_batch(
                        batch=batch)  # type: ignore
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "evaluate_batch() must return a dictionary "
                    "of string names to Tensor metrics",
                )
                # TODO: For performance, only call .cpu() at the end of validation.
                batch_metrics.append(
                    self._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            for callback in self.callbacks.values():
                callback.on_validation_epoch_end(batch_metrics)

            metrics = self._reduce_metrics(
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=self._prepare_metrics_reducers(keys=keys),
            )

            if self.hvd_config.use:
                num_inputs *= hvd.size()

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"eval() must return a dictionary, got {type(metrics)}.")

                metrics = self._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            self._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

        if self.hvd_config.use and any(
            util.is_overridden(c.on_validation_end, pytorch.PyTorchCallback)
            or util.is_overridden(c.on_validation_step_end, pytorch.PyTorchCallback)
            for c in self.callbacks.values()
        ):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example 9
    def _train_for_step(self, step_id: int, num_batches: int,
                        total_batches_processed: int) -> workload.Response:
        check.gt(step_id, 0)
        self.context.reset_reducers()

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        for model in self.context.models:
            model.train()

        start = total_batches_processed
        end = start + num_batches

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch_start_time = time.time()
            self.prof.update_batch_idx(batch_idx)
            with self.prof.record_timing("dataloader_next"):
                batch = next(self.training_iterator)
            batch_inputs = self.trial.get_batch_length(batch)
            num_inputs += batch_inputs

            with self.prof.record_timing("to_device"):
                batch = self.context.to_device(batch)

            self.context._current_batch_idx = batch_idx
            if self.context.is_epoch_start():
                for callback in self.callbacks.values():
                    with self.prof.record_timing(
                            f"callbacks.{callback.__class__.__name__}.on_training_epoch_start"
                    ):
                        callback.on_training_epoch_start()
            self.context._loss_ids = {}

            with self.prof.record_timing("train_batch"):
                if self.context.profiler:
                    with self.context.profiler as torch_profiler:
                        tr_metrics = self.trial.train_batch(
                            batch=batch,
                            epoch_idx=self.get_epoch_idx(batch_idx),
                            batch_idx=batch_idx,
                        )
                        torch_profiler.step()
                else:
                    tr_metrics = self.trial.train_batch(
                        batch=batch,
                        epoch_idx=self.get_epoch_idx(batch_idx),
                        batch_idx=batch_idx,
                    )
            if self._should_update_scaler():
                self.context._scaler.update()
            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}
            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )

            # Step learning rate of a pytorch.LRScheduler.
            with self.prof.record_timing("step_lr_schedulers"):
                for lr_scheduler in self.context.lr_schedulers:
                    self._auto_step_lr_scheduler_per_batch(
                        batch_idx, lr_scheduler)

            with self.prof.record_timing("from_device"):
                for name, metric in tr_metrics.items():
                    # Convert PyTorch metric values to NumPy, so that
                    # `det.util.encode_json` handles them properly without
                    # needing a dependency on PyTorch.
                    if isinstance(metric, torch.Tensor):
                        metric = metric.cpu().detach().numpy()
                    tr_metrics[name] = metric

            batch_dur = time.time() - batch_start_time
            samples_per_second = batch_inputs / batch_dur
            self.prof.record_metric("samples_per_second", samples_per_second)
            per_batch_metrics.append(tr_metrics)

        # Aggregate and reduce training metrics from all the training processes.
        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            with self.prof.record_timing("average_training_metrics"):
                per_batch_metrics = self._average_training_metrics(
                    per_batch_metrics)
        if self.hvd_config.use:
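            # Approximate the global input count by scaling the per-worker count by
            # the world size (assumes each worker processed a similar share of data).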
            num_inputs *= hvd.size()
        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch
        # metrics are even logical for a custom reducer.
        with self.prof.record_timing("reduce_metrics"):
            metrics["avg_metrics"].update(
                self._convert_metrics_to_numpy(
                    self.context.reduce_metrics(for_training=True)))

        if not self.is_chief:
            # The training metrics are reported only in the chief process.
            return workload.Skipped()

        logging.debug(
            f"Done training step: {num_inputs} records in {num_batches} batches."
        )

        return metrics

    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        step_start_time = time.time()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for callback in self.callbacks.values():
                callback.on_validation_epoch_start()
            for idx, batch in enumerate(self.validation_loader):
                if self.context.experimental._auto_to_device:
                    batch = self.context.to_device(batch)
                num_inputs += self.trial.get_batch_length(batch)

                if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                    vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                            batch_idx=idx)
                else:
                    vld_metrics = self.trial.evaluate_batch(
                        batch=batch)  # type: ignore
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "evaluate_batch() must return a dictionary "
                    "of string names to Tensor metrics",
                )
                # TODO: For performance, only call .cpu() at the end of validation.
                batch_metrics.append(
                    pytorch._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            for callback in self.callbacks.values():
                callback.on_validation_epoch_end(batch_metrics)

            metrics = pytorch._reduce_metrics(
                self.context.distributed,
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=pytorch._prepare_metrics_reducers(
                    self.trial.evaluation_reducer(), keys=keys),
            )

            # Gather a list of per-worker (num_inputs, num_batches) tuples.
            input_counts = self.context.distributed.gather(
                (num_inputs, idx + 1))
            if self.context.distributed.rank == 0:
                assert input_counts is not None
                # Reshape and sum.
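                # e.g. input_counts == [(512, 4), (480, 4)] on two workers
                # -> num_inputs == 992, num_batches == 8.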
                num_inputs, num_batches = [sum(n) for n in zip(*input_counts)]

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"eval() must return a dictionary, got {type(metrics)}.")

                metrics = pytorch._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            pytorch._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

        if self.context.distributed.size > 1 and any(
            util.is_overridden(c.on_validation_end, pytorch.PyTorchCallback)
            or util.is_overridden(c.on_validation_step_end, pytorch.PyTorchCallback)
            for c in self.callbacks.values()
        ):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return {}

        # Skip reporting timings if evaluate_full_dataset() was defined.  This is far less common
        # than evaluate_batch() and we can't know how the user processed their validation data.
        if self._evaluate_batch_defined():
            step_duration = time.time() - step_start_time
            logging.info(
                det.util.make_timing_log("validated", step_duration,
                                         num_inputs, num_batches))

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
    def _train_for_step(self, step_id: int, num_batches: int,
                        total_batches_processed: int) -> workload.Response:
        self.prof.set_training(True)
        check.gt(step_id, 0)
        step_start_time = time.time()
        self.context.reset_reducers()

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        for model in self.context.models:
            model.train()

        start = total_batches_processed
        end = start + num_batches

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            self.steps_completed += 1
            batch_start_time = time.time()
            self.prof.update_batch_idx(batch_idx)
            with self.prof.record_timing("dataloader_next",
                                         requires_sync=False):
                batch = next(self.training_iterator)
            batch_inputs = self.trial.get_batch_length(batch)
            num_inputs += batch_inputs

            if self.context.experimental._auto_to_device:
                with self.prof.record_timing("to_device", accumulate=True):
                    batch = self.context.to_device(batch)

            self.context._current_batch_idx = batch_idx
            epoch_idx = self.get_epoch_idx(batch_idx)
            if self.context.is_epoch_start():
                for callback in self.callbacks.values():
                    with self.prof.record_timing(
                            f"callbacks.{callback.__class__.__name__}.on_training_epoch_start"
                    ):
                        sig = signature(callback.on_training_epoch_start)
                        if sig.parameters:
                            callback.on_training_epoch_start(epoch_idx)
                        else:
                            logging.warning(
                                "on_training_epoch_start() without parameters is deprecated"
                                " since 0.17.8. Please add epoch_idx parameter."
                            )
                            callback.on_training_epoch_start()  # type: ignore[call-arg]

            self.context._loss_ids = {}

            with self.prof.record_timing("train_batch", requires_sync=False):
                if self.context.profiler:
                    with self.context.profiler as torch_profiler:
                        tr_metrics = self.trial.train_batch(
                            batch=batch,
                            epoch_idx=epoch_idx,
                            batch_idx=batch_idx,
                        )
                        torch_profiler.step()
                else:
                    tr_metrics = self.trial.train_batch(
                        batch=batch,
                        epoch_idx=epoch_idx,
                        batch_idx=batch_idx,
                    )
            if self._should_update_scaler():
                self.context._scaler.update()
            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}
            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )

            # Step learning rate of a pytorch.LRScheduler.
            with self.prof.record_timing("step_lr_schedulers"):
                for lr_scheduler in self.context.lr_schedulers:
                    self._auto_step_lr_scheduler_per_batch(
                        batch_idx, lr_scheduler)

            with self.prof.record_timing("from_device"):
                for name, metric in tr_metrics.items():
                    # Convert PyTorch metric values to NumPy, so that
                    # `det.util.encode_json` handles them properly without
                    # needing a dependency on PyTorch.
                    if isinstance(metric, torch.Tensor):
                        metric = metric.cpu().detach().numpy()
                    tr_metrics[name] = metric

            batch_dur = time.time() - batch_start_time
            samples_per_second = batch_inputs / batch_dur
            samples_per_second *= self.context.distributed.size
            self.prof.record_metric("samples_per_second", samples_per_second)
            per_batch_metrics.append(tr_metrics)

            if self.context.is_epoch_end():
                for callback in self.callbacks.values():
                    with self.prof.record_timing(
                            f"callbacks.{callback.__class__.__name__}.on_training_epoch_end"
                    ):
                        callback.on_training_epoch_end(epoch_idx)

        # Aggregate and reduce training metrics from all the training processes.
        if self.context.distributed.size > 1 and self.context._average_training_metrics:
            with self.prof.record_timing("average_training_metrics"):
                per_batch_metrics = pytorch._combine_and_average_training_metrics(
                    self.context.distributed, per_batch_metrics)
        num_inputs *= self.context.distributed.size
        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch
        # metrics are even logical for a custom reducer.
        with self.prof.record_timing("reduce_metrics"):
            metrics["avg_metrics"].update(
                pytorch._convert_metrics_to_numpy(
                    self.context.reduce_metrics(for_training=True)))

        if not self.is_chief:
            # The training metrics are reported only in the chief process.
            return {}

        step_duration = time.time() - step_start_time
        logging.info(
            det.util.make_timing_log("trained", step_duration, num_inputs,
                                     num_batches))

        return metrics