Example #1
    def wrap_model(self, model: torch.nn.Module) -> torch.nn.Module:
        """Returns a wrapped model."""

        if self.env.managed_training:
            check.false(self._use_apex, "Must call wrap_model() before configure_apex_amp.")

            model = model.to(self.device)
            if not self.hvd_config.use and self.n_gpus > 1:
                check.eq(
                    self.hvd_config.aggregation_frequency,
                    1,
                    "Please enable `optimized_parallel` to use aggregation "
                    "frequency greater than 1 for single machine multi-GPU "
                    "training.",
                )
                model = nn.DataParallel(model)
                logging.debug("Initialized model for native parallel training.")

        model_id = len(self.models)
        self._main_model.__setattr__(f"model_{model_id}", model)

        if self.experimental._auto_amp:
            model = self.autocast_forward_pass(model)

        self.models.append(model)
        return model
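A minimal usage sketch, not taken from the source: inside a Determined PyTorchTrial's constructor, the model is passed through wrap_model() before any Apex configuration, matching the ordering that the check.false() call above enforces. The trial class and the tiny network are illustrative.

import torch.nn as nn
from determined import pytorch

class MyTrial(pytorch.PyTorchTrial):  # hypothetical trial; other required methods omitted
    def __init__(self, context: pytorch.PyTorchTrialContext) -> None:
        self.context = context
        # wrap_model() must run before configure_apex_amp(), per the check above.
        self.model = self.context.wrap_model(
            nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
        )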
Example #2
 def _run(self) -> None:
     for w, args, response_func in self.workloads:
         if w.kind == workload.Workload.Kind.RUN_STEP:
             try:
                 response_func(
                     util.wrap_metrics(
                         self._train_for_step(
                             w.step_id,
                             w.num_batches,
                             w.total_batches_processed,
                         ),
                         self.context.get_stop_requested(),
                         invalid_hp=False,
                         init_invalid_hp=False,
                     )
                 )
             except det.InvalidHP as e:
                 logging.info(
                     "Invalid hyperparameter exception in trial train step: {}".format(e)
                 )
                 response_func(
                     util.wrap_metrics(
                         {},
                         self.context.get_stop_requested(),
                         invalid_hp=True,
                         init_invalid_hp=False,
                     )
                 )
         elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
             try:
                 response_func(
                     util.wrap_metrics(
                         self._compute_validation_metrics(),
                         self.context.get_stop_requested(),
                         invalid_hp=False,
                         init_invalid_hp=False,
                     )
                 )
             except det.InvalidHP as e:
                 logging.info(
                     "Invalid hyperparameter exception in trial validation step: {}".format(e)
                 )
                 response_func(
                     util.wrap_metrics(
                         {},
                         self.context.get_stop_requested(),
                         invalid_hp=True,
                         init_invalid_hp=False,
                     )
                 )
         elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
             check.eq(len(args), 1)
             check.is_instance(args[0], pathlib.Path)
             path = cast(pathlib.Path, args[0])
             response_func(self._save(path))
         elif w.kind == workload.Workload.Kind.TERMINATE:
             response_func({} if self.is_chief else workload.Skipped())
             break
         else:
             raise AssertionError("Unexpected workload: {}".format(w.kind))
Example #3
    def barrier(self,
                num_connections: int,
                message: Any = None,
                timeout: Optional[int] = None) -> List[Any]:
        """
        This is a one-sided barrier, where the chief blocks until
        all non-chief trial containers have sent a message.
        """
        check.eq(len(self.sockets), 1)
        messages = []  # type: List[Any]
        start_time = time.time()

        for _ in range(num_connections):
            if timeout:
                message_received, barrier_message = self.receive_non_blocking(
                    send_rank=0, deadline=start_time + timeout)

                if not message_received:
                    return messages

            else:
                barrier_message = self.receive_blocking(0)

            check.is_instance(barrier_message, _OneSidedBarrier)
            messages.append(barrier_message.message)
            self.sockets[0].send_pyobj(
                _OneSidedBarrier(message=message))  # type: ignore

        return messages
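A hedged chief-side usage sketch; broadcast_server and num_workers are assumed names, and the only API relied on is the barrier() signature shown above.

def wait_for_workers(broadcast_server, num_workers: int) -> list:
    # Block until every non-chief container has checked in, or the shared
    # deadline (start_time + timeout) passes; each sender receives the reply.
    messages = broadcast_server.barrier(
        num_connections=num_workers - 1,
        message="proceed",
        timeout=60,
    )
    if len(messages) < num_workers - 1:
        raise RuntimeError("barrier timed out before all workers reported in")
    return messages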
Example #4
    def require_horovod_type(self, horovod_type: str, reason: str) -> None:
        """
        Declare the required type of horovod and give a unique reason as to why it is required.

        The reason makes for clear error reporting if require_horovod_type() is called a second
        time but with a different type.
        """

        known_types = {"tensorflow", "tensorflow.keras", "torch"}
        check.is_in(horovod_type, known_types,
                    "Unknown horovod type requested.")

        if self._poly_hvd_type is not None:
            check.eq(
                horovod_type,
                self._poly_hvd_type,
                f"require_horovod_type() called with with type {horovod_type} after a previous "
                f"call with type {self._poly_hvd_type} in the same process. The reason for the "
                f"first call was '{self._poly_hvd_first_reason}'; the reason for this call is "
                f"'{reason}'.",
            )
        else:
            self._poly_hvd_type = horovod_type
            self._poly_hvd_first_reason = reason
            # If horovod has not been imported yet, do it now.
            try:
                self._poly_hvd_module = importlib.import_module(
                    f"horovod.{horovod_type}")
            except ImportError:
                pass
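A hedged usage sketch in which ctx stands for whatever object carries this method: repeating the same type is fine, while a conflicting type fails and the error message cites both stored reasons, as the docstring describes.

ctx.require_horovod_type("torch", "PyTorchTrial was instantiated")
ctx.require_horovod_type("torch", "wrap_optimizer() was called")          # same type: ok
ctx.require_horovod_type("tensorflow.keras", "TFKerasTrial was created")  # raises via check.eq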
Example #5
    def recv(self) -> Any:

        obj = self._sub_socket.recv_pyobj()

        if isinstance(obj, _SerialMessage):
            check.eq(obj.serial, self._recv_serial, "Out-of-order server message detected")
            self._recv_serial += 1
            return obj.payload
        raise AssertionError(f"Unexpected message type encountered: {type(obj)}")
Example #6
def validate_batch_metrics(batch_metrics: List[Dict[str, Any]]) -> None:
    metric_dict = _list_to_dict(batch_metrics)

    # We expect that all batches have the same set of metrics.
    expected_keys = metric_dict.keys()
    for idx, batch_metric in enumerate(batch_metrics):
        keys = batch_metric.keys()
        if expected_keys == keys:
            continue

        check.eq(expected_keys, keys, "inconsistent training metrics: index: {}".format(idx))
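A quick usage example of the consistency check: batches reporting identical metric names pass, while a batch missing a name trips check.eq with that batch's index.

validate_batch_metrics([{"loss": 0.9, "acc": 0.50}, {"loss": 0.7, "acc": 0.55}])  # passes
validate_batch_metrics([{"loss": 0.9, "acc": 0.50}, {"loss": 0.7}])               # fails at index 1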
Example #7
    def _recv_one(self) -> Tuple[Any, type]:
        """
        Receive one _SerialMessage from the socket and confirm that it is in-order.
        """

        obj = self._pull_socket.recv_pyobj()

        if isinstance(obj, _ExceptionMessage):
            return None, _ExceptionMessage

        if isinstance(obj, _SerialMessage):
            check.eq(obj.serial, self._recv_serial, "Out-of-order client message detected")
            return obj.payload, _SerialMessage

        raise AssertionError(f"Unexpected message type encountered: {type(obj)}")
Example #8
def binary_error_rate(predictions: torch.Tensor,
                      labels: torch.Tensor) -> float:
    """Return the classification error rate for binary classification."""
    check.eq(predictions.shape[0], labels.shape[0])
    check.is_in(len(predictions.shape), [1, 2])
    if len(predictions.shape) == 2:
        check.eq(predictions.shape[1], 1)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")

    if len(predictions.shape) > 1:
        predictions = torch.squeeze(predictions)

    errors = torch.sum(
        labels.to(torch.long) != torch.round(predictions).to(torch.long))
    result = float(errors) / predictions.shape[0]  # type: float
    return result
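A quick worked example: the rounded predictions [1, 0, 1, 0] disagree with the labels [1, 0, 0, 1] on two of four samples, so the error rate is 0.5.

preds = torch.tensor([0.9, 0.2, 0.7, 0.4])
labels = torch.tensor([1.0, 0.0, 0.0, 1.0])
assert binary_error_rate(preds, labels) == 0.5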
Example #9
def _make_test_workloads(config: det.ExperimentConfig) -> workload.Stream:
    interceptor = workload.WorkloadResponseInterceptor()

    logging.info("Training one batch")
    yield from interceptor.send(workload.train_workload(1))
    metrics = interceptor.metrics_result()
    batch_metrics = metrics["metrics"]["batch_metrics"]
    check.eq(len(batch_metrics), config.scheduling_unit())
    logging.info(f"Finished training, metrics: {batch_metrics}")

    logging.info("Validating one batch")
    yield from interceptor.send(workload.validation_workload(1))
    validation = interceptor.metrics_result()
    v_metrics = validation["metrics"]["validation_metrics"]
    logging.info(f"Finished validating, validation metrics: {v_metrics}")

    logging.info("Saving a checkpoint.")
    yield workload.checkpoint_workload(), workload.ignore_workload_response
    logging.info("Finished saving a checkpoint.")
Example #10
    def __init__(
        self,
        num_connections: Optional[int] = None,
        ports: Optional[List[int]] = None,
        port_range: Optional[Tuple[int, int]] = None,
    ) -> None:
        self.context = zmq.Context()  # type: ignore
        self.sockets = []  # type: List[zmq.Socket]
        self.ports = []  # type: List[int]

        if ports:
            check.is_none(port_range)
            self._bind_to_specified_ports(ports=ports)
            check.eq(len(self.ports), len(ports))
        else:
            check.is_not_none(num_connections)
            check.is_not_none(port_range)
            num_connections = cast(int, num_connections)
            port_range = cast(Tuple[int, int], port_range)
            self._bind_to_random_ports(port_range=port_range,
                                       num_connections=num_connections)
            check.eq(len(self.ports), num_connections)
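An illustrative construction; the enclosing class name is not shown in the snippet, so ZMQServer below is an assumption. Either a list of explicit ports or a (num_connections, port_range) pair must be given, per the checks above.

server = ZMQServer(num_connections=3, port_range=(10000, 11000))  # class name assumed
assert len(server.ports) == 3  # one bound port per expected connection

pinned = ZMQServer(ports=[5555, 5556])  # or bind to specific ports instead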
Example #11
    def _prepare_metrics_reducers(self, keys: Any) -> Dict[str, pytorch.Reducer]:
        metrics_reducers = {}  # type: Dict[str, pytorch.Reducer]
        reducer = self.trial.evaluation_reducer()
        if isinstance(reducer, Dict):
            metrics_reducers = reducer
            check.eq(
                metrics_reducers.keys(),
                keys,
                "Please provide a single evaluation reducer or "
                "provide a reducer for every validation metric. "
                f"Expected keys: {keys}, provided keys: {metrics_reducers.keys()}.",
            )
        elif isinstance(reducer, pytorch.Reducer):
            for key in keys:
                metrics_reducers[key] = reducer

        for key in keys:
            check.true(
                isinstance(metrics_reducers[key], pytorch.Reducer),
                "Please select `determined.pytorch.Reducer` for reducing validation metrics.",
            )

        return metrics_reducers
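A hedged sketch of what a trial's evaluation_reducer() might return so that the dict branch above is taken; the trial fragment and metric names are illustrative, and Reducer.AVG / Reducer.SUM are assumed to be the standard members of determined.pytorch.Reducer.

from typing import Dict
from determined import pytorch

class MyTrial(pytorch.PyTorchTrial):  # illustrative fragment; other methods omitted
    def evaluation_reducer(self) -> Dict[str, pytorch.Reducer]:
        # One reducer per validation metric; the keys must match the metric names
        # returned by evaluate_batch(), or the check.eq above fails.
        return {
            "validation_loss": pytorch.Reducer.AVG,
            "num_correct": pytorch.Reducer.SUM,
        }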
Example #12
    def __init__(
        self,
        x: ArrayLike,
        y: ArrayLike,
        batch_size: int,
        sample_weights: Optional[np.ndarray] = None,
        drop_leftovers: bool = False,
    ):
        """
        If converting numpy array data to Sequence to optimize performance, consider
        using ArrayLikeAdapter.

        Args:
            x: Input data. It could be:
                1) A Numpy array (or array-like), or a list of arrays (in case the model
                has multiple inputs).
                2) A dict mapping input names to the corresponding array, if the model
                has named inputs.

            y: Target data. Like the input data ``x``, it can be a Numpy array
                (or array-like), or a list/dict of arrays.

            batch_size: Number of samples per batch.

            sample_weights: Numpy array of weights for the samples.

            drop_leftovers: If True, drop the data that cannot complete the last batch. This
                argument is ignored if x is a Sequence or a Dataset.
        """

        if not (isinstance(x, np.ndarray) or _is_list_of_numpy_array(x)
                or _is_dict_of_numpy_array(x)):
            raise det.errors.InvalidDataTypeException(
                type(x),
                "Data which is not tf.data.Datasets or tf.keras.utils.Sequence objects must be a "
                "numpy array or a list/dict of numpy arrays. See the instructions below for "
                f"details:\n{keras.TFKerasTrial.build_training_data_loader.__doc__}",
            )
        if not (isinstance(y, np.ndarray) or _is_list_of_numpy_array(y)
                or _is_dict_of_numpy_array(y)):
            raise det.errors.InvalidDataTypeException(
                type(y),
                "Data which is not tf.data.Datasets or tf.keras.utils.Sequence objects must be a "
                "numpy array or a list/dict of numpy arrays. See the instructions below for "
                f"details:\n{keras.TFKerasTrial.build_training_data_loader.__doc__}",
            )

        self._x_length = _length_of_multi_arraylike(x)
        self._y_length = _length_of_multi_arraylike(y)

        check.eq(self._x_length, self._y_length,
                 "Length of x and y do not match.")
        check.check_gt_eq(self._x_length, batch_size,
                          "Batch size is too large for the input data.")
        if sample_weights is not None:
            check.eq(
                self._x_length,
                len(sample_weights),
                "Lengths of input data and sample weights do not match.",
            )

        self.x = x
        self.y = y
        self.sample_weight = sample_weights

        self.batch_size = batch_size
        self.drop_leftovers = drop_leftovers
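An illustrative construction, assuming the class that owns this __init__ is the ArrayLikeAdapter named in the docstring: 100 samples batched in groups of 16, with the 4-sample remainder dropped.

import numpy as np

x = np.random.rand(100, 8).astype(np.float32)  # 100 samples, 8 features each
y = np.random.randint(0, 2, size=100)          # binary targets, same length as x
adapter = ArrayLikeAdapter(x=x, y=y, batch_size=16, drop_leftovers=True)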
Example #13
    def configure_apex_amp(
        self,
        models: Union[torch.nn.Module, List[torch.nn.Module]],
        optimizers: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]],
        enabled: Optional[bool] = True,
        opt_level: Optional[str] = "O1",
        cast_model_type: Optional[torch.dtype] = None,
        patch_torch_functions: Optional[bool] = None,
        keep_batchnorm_fp32: Optional[Union[bool, str]] = None,
        master_weights: Optional[bool] = None,
        loss_scale: Optional[Union[float, str]] = None,
        cast_model_outputs: Optional[torch.dtype] = None,
        num_losses: Optional[int] = 1,
        verbosity: Optional[int] = 1,
        min_loss_scale: Optional[float] = None,
        max_loss_scale: Optional[float] = 2.0**24,
    ) -> Tuple:
        """
        Configure automatic mixed precision for your models and optimizers using NVIDIA's Apex
        PyTorch extension. Note that details for apex.amp are handled automatically within
        Determined after this call.

        This function must be called **after** you have finished constructing your models and
        optimizers with :meth:`wrap_model` and :meth:`wrap_optimizer`.

        This function has the same arguments as
        `apex.amp.initialize <https://nvidia.github.io/apex/amp.html#apex.amp.initialize>`_.

        .. warning::
            When using distributed training and automatic mixed precision,
            we only support ``num_losses=1`` and calling backward on the loss once.

        Arguments:
            models (``torch.nn.Module`` or list of ``torch.nn.Module`` s):  Model(s) to modify/cast.
            optimizers (``torch.optim.Optimizer`` or list of ``torch.optim.Optimizer`` s):
                Optimizers to modify/cast. REQUIRED for training.
            enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops,
                so your script should run as if Amp were not present.
            opt_level (str, optional, default="O1"):  Pure or mixed precision optimization level.
                Accepted values are "O0", "O1", "O2", and "O3", explained in detail in the
                Apex documentation linked above.
            cast_model_type (``torch.dtype``, optional, default=None):  Optional property override;
                see the Apex documentation.
            patch_torch_functions (bool, optional, default=None):  Optional property override.
            keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.
                If passed as a string, must be the string "True" or "False".
            master_weights (bool, optional, default=None):  Optional property override.
            loss_scale (float or str, optional, default=None):  Optional property override.
                If passed as a string, must be a string representing a number, e.g., "128.0",
                or the string "dynamic".
            cast_model_outputs (torch.dtype, optional, default=None):  Option to ensure that
                the outputs of your model are always cast to a particular type regardless of
                ``opt_level``.
            num_losses (int, optional, default=1):  Option to tell Amp in advance how many
                losses/backward passes you plan to use.  When used in conjunction with the
                ``loss_id`` argument to ``amp.scale_loss``, enables Amp to use a different
                loss scale per loss/backward pass, which can improve stability.
                If ``num_losses`` is left to 1, Amp will still support multiple losses/backward
                passes, but use a single global loss scale for all of them.
            verbosity (int, default=1):  Set to 0 to suppress Amp-related output.
            min_loss_scale (float, default=None):  Sets a floor for the loss scale values that
                can be chosen by dynamic loss scaling.  The default value of None means that no
                floor is imposed. If dynamic loss scaling is not used, `min_loss_scale` is ignored.
            max_loss_scale (float, default=2.**24):  Sets a ceiling for the loss scale values
                that can be chosen by dynamic loss scaling.  If dynamic loss scaling is not used,
                `max_loss_scale` is ignored.

        Returns:
            Model(s) and optimizer(s) modified according to the ``opt_level``.
            If the ``models`` or ``optimizers`` arguments were lists, the corresponding
            return values will also be lists.
        """
        if not self.env.managed_training:
            return models, optimizers

        check.is_none(self._scaler, "Do not mix APEX with PyTorch AMP")

        check.false(self._use_apex,
                    "Please only call configure_apex_amp once.")
        if self.distributed.size > 1:
            check.eq(
                num_losses,
                1,
                "When using parallel/distributed training, "
                "Determined only supports configure_apex_amp with num_losses = 1",
            )

        self._use_apex = True

        if self.distributed.size > 1:
            check.eq(
                self._aggregation_frequency,
                1,
                "Mixed precision training (AMP) is not supported with "
                "aggregation frequency > 1.",
            )

        check.true(
            torch.cuda.is_available(),
            "Mixed precision training (AMP) is supported only on GPU slots.",
        )

        if self._distributed_backend.use_torch():
            # We need to get the pre-wrapped input models to initialize APEX, because Apex
            # must be applied to the underlying models before they are re-wrapped with
            # DistributedDataParallel below.
            if isinstance(models, list):
                models = [
                    self._wrapped_models[wrapped_model]
                    for wrapped_model in models
                ]
            else:
                models = self._wrapped_models[models]

        logging.info(
            f"Enabling mixed precision training with opt_level: {opt_level}.")
        models, optimizers = apex.amp.initialize(
            models=models,
            optimizers=optimizers,
            enabled=enabled,
            opt_level=opt_level,
            cast_model_type=cast_model_type,
            patch_torch_functions=patch_torch_functions,
            keep_batchnorm_fp32=keep_batchnorm_fp32,
            master_weights=master_weights,
            loss_scale=loss_scale,
            cast_model_outputs=cast_model_outputs,
            num_losses=num_losses,
            min_loss_scale=min_loss_scale,
            max_loss_scale=max_loss_scale,
            verbosity=verbosity if self.distributed.get_rank() == 0
            or self.env.experiment_config.debug_enabled() else 0,
        )

        self.models = models if isinstance(models, list) else [models]

        if self.distributed.size > 1 and self._distributed_backend.use_torch():
            # If Torch DDP is in use, re-wrap the models
            self.models = [
                self._PyTorchDistributedDataParallel(model)
                for model in self.models
            ]

        self.optimizers = optimizers if isinstance(optimizers, list) else [optimizers]
        return models, optimizers
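A minimal sketch of the call order the docstring requires (wrap the model and optimizer first, then configure Apex); context, build_model(), and the SGD settings are illustrative, not from the source.

model = context.wrap_model(build_model())
optimizer = context.wrap_optimizer(torch.optim.SGD(model.parameters(), lr=0.1))
model, optimizer = context.configure_apex_amp(
    models=model,
    optimizers=optimizer,
    opt_level="O1",  # mixed precision via patched torch functions
    num_losses=1,    # required for distributed training, per the warning above
)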
Example #14
    def select_checkpoint(
        self,
        latest: bool = False,
        best: bool = False,
        uuid: Optional[str] = None,
        sort_by: Optional[str] = None,
        smaller_is_better: Optional[bool] = None,
    ) -> checkpoint.Checkpoint:
        """
        Return the :class:`~determined.experimental.Checkpoint` instance with the best
        validation metric as defined by the ``sort_by`` and ``smaller_is_better``
        arguments.

        Exactly one of the ``best``, ``latest``, or ``uuid`` parameters must be set.

        Arguments:
            latest (bool, optional): Return the most recent checkpoint.

            best (bool, optional): Return the checkpoint with the best validation
                metric as defined by the ``sort_by`` and ``smaller_is_better``
                arguments. If ``sort_by`` and ``smaller_is_better`` are not
                specified, the values from the associated experiment
                configuration will be used.

            uuid (string, optional): Return the checkpoint for the specified UUID.

            sort_by (string, optional): The name of the validation metric to
                order checkpoints by. If this parameter is unset the metric defined
                in the related experiment configuration searcher field will be
                used.

            smaller_is_better (bool, optional): Whether to sort the
                metric above in ascending or descending order. If ``sort_by`` is unset,
                this parameter is ignored. By default, the value of ``smaller_is_better``
                from the experiment's configuration is used.
        """
        check.eq(
            sum([int(latest), int(best),
                 int(uuid is not None)]),
            1,
            "Exactly one of latest, best, or uuid must be set",
        )

        check.eq(
            sort_by is None,
            smaller_is_better is None,
            "sort_by and smaller_is_better must be set together",
        )

        if sort_by is not None and not best:
            raise AssertionError(
                "`sort_by` and `smaller_is_better` parameters can only be used with `best`"
            )

        if uuid:
            resp = api.get(self._master, "/api/v1/checkpoints/{}".format(uuid))
            return checkpoint.Checkpoint.from_json(resp.json()["checkpoint"],
                                                   master=self._master)

        r = api.get(
            self._master,
            "/api/v1/trials/{}/checkpoints".format(self.id),
            # The default sort order from the API is by batch number; passing
            # order_by=2 requests descending order.
            params={
                "order_by": 2
            },
        ).json()
        checkpoints = r["checkpoints"]

        if not checkpoints:
            raise AssertionError("No checkpoint found for trial {}".format(
                self.id))

        if latest:
            return checkpoint.Checkpoint.from_json(checkpoints[0],
                                                   master=self._master)

        if not sort_by:
            sort_by = checkpoints[0]["experimentConfig"]["searcher"]["metric"]
            smaller_is_better = checkpoints[0]["experimentConfig"]["searcher"][
                "smaller_is_better"]

        best_checkpoint_func = min if smaller_is_better else max
        return checkpoint.Checkpoint.from_json(
            best_checkpoint_func(
                [c for c in checkpoints if c["metrics"] is not None],
                key=lambda x: x["metrics"]["validationMetrics"][sort_by],
            ),
            master=self._master,
        )
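A hedged usage sketch in which trial stands for the object exposing this method: exactly one of latest, best, or uuid may be set, and sort_by travels with smaller_is_better, as the two check.eq calls above enforce. The metric name is illustrative.

latest_ckpt = trial.select_checkpoint(latest=True)

best_by_loss = trial.select_checkpoint(
    best=True,
    sort_by="validation_loss",
    smaller_is_better=True,
)

# trial.select_checkpoint(latest=True, best=True) would raise: only one selector allowed.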
Example #15
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        step_start_time = time.time()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for callback in self.callbacks.values():
                callback.on_validation_epoch_start()
            for idx, batch in enumerate(self.validation_loader):
                if self.context.experimental._auto_to_device:
                    batch = self.context.to_device(batch)
                num_inputs += self.trial.get_batch_length(batch)

                if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                    vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                            batch_idx=idx)
                else:
                    vld_metrics = self.trial.evaluate_batch(
                        batch=batch)  # type: ignore
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "validation_metrics() must return a "
                    "dictionary of string names to Tensor "
                    "metrics",
                )
                # TODO: for performance, move metrics to the CPU only at the end of validation.
                batch_metrics.append(
                    pytorch._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            for callback in self.callbacks.values():
                callback.on_validation_epoch_end(batch_metrics)

            metrics = pytorch._reduce_metrics(
                self.context.distributed,
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=pytorch._prepare_metrics_reducers(
                    self.trial.evaluation_reducer(), keys=keys),
            )

            # Gather a list of per-worker (num_inputs, num_batches) tuples.
            input_counts = self.context.distributed.gather(
                (num_inputs, idx + 1))
            if self.context.distributed.rank == 0:
                assert input_counts is not None
                # Reshape and sum.
                num_inputs, num_batches = [sum(n) for n in zip(*input_counts)]

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"eval() must return a dictionary, got {type(metrics)}.")

                metrics = pytorch._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            pytorch._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

        if self.context.distributed.size > 1 and any(
                map(
                    lambda c: util.is_overridden(
                        c.on_validation_end, pytorch.
                        PyTorchCallback) or util.is_overridden(
                            c.on_validation_step_end, pytorch.PyTorchCallback),
                    self.callbacks.values(),
                )):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return {}

        # Skip reporting timings if evaluate_full_dataset() was defined.  This is far less common
        # than evaluate_batch() and we can't know how the user processed their validation data.
        if self._evaluate_batch_defined():
            step_duration = time.time() - step_start_time
            logging.info(
                det.util.make_timing_log("validated", step_duration,
                                         num_inputs, num_batches))

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example #16
    def _load(self) -> None:
        if not self.load_path:
            return

        # Backwards compat with older checkpoint formats. List is newest to
        # oldest known state_dict locations.
        potential_paths = [
            ["state_dict.pth"],
            ["determined", "state_dict.pth"],
            ["pedl", "state_dict.pth"],
            ["checkpoint.pt"],
        ]

        checkpoint: Optional[Dict[str, Any]] = None
        for ckpt_path in potential_paths:
            maybe_ckpt = self.load_path.joinpath(*ckpt_path)
            if maybe_ckpt.exists():
                checkpoint = torch.load(str(maybe_ckpt),
                                        map_location="cpu")  # type: ignore
                break
        if checkpoint is None or not isinstance(checkpoint, dict):
            return

        for callback in self.callbacks.values():
            callback.on_checkpoint_load_start(checkpoint)

        if "model_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("models_state_dict", checkpoint)
            check.eq(len(self.context.models), 1)
            self.context.models[0].load_state_dict(
                checkpoint["model_state_dict"])
        else:
            for idx, model in enumerate(self.context.models):
                model.load_state_dict(checkpoint["models_state_dict"][idx])

        if "optimizer_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("optimizers_state_dict", checkpoint)
            check.eq(len(self.context.optimizers), 1)
            self.context.optimizers[0].load_state_dict(
                checkpoint["optimizer_state_dict"])
        else:
            for idx, optimizer in enumerate(self.context.optimizers):
                optimizer.load_state_dict(
                    checkpoint["optimizers_state_dict"][idx])

        if "lr_scheduler" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("lr_schedulers_state_dict", checkpoint)
            check.eq(len(self.context.lr_schedulers), 1)
            self.context.lr_schedulers[0].load_state_dict(
                checkpoint["lr_scheduler"])
        else:
            for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
                lr_scheduler.load_state_dict(
                    checkpoint["lr_schedulers_state_dict"][idx])

        if "scaler_state_dict":
            if self.context._scaler:
                self.context._scaler.load_state_dict(
                    checkpoint["scaler_state_dict"])
            else:
                logging.warning(
                    "There exists scaler_state_dict in checkpoint but the experiment is not using "
                    "AMP.")
        else:
            if self.context._scaler:
                logging.warning(
                    "The experiment is using AMP but scaler_state_dict does not exist in the "
                    "checkpoint.")

        if "amp_state" in checkpoint:
            if self.context._use_apex:
                apex.amp.load_state_dict(checkpoint["amp_state"])
            else:
                logging.warning(
                    "There exists amp_state in checkpoint but the experiment is not using Apex."
                )
        else:
            if self.context._use_apex:
                logging.warning(
                    "The experiment is using Apex but amp_state does not exist in the checkpoint."
                )

        if "rng_state" in checkpoint:
            rng_state = checkpoint["rng_state"]
            np.random.set_state(rng_state["np_rng_state"])
            random.setstate(rng_state["random_rng_state"])
            torch.random.set_rng_state(rng_state["cpu_rng_state"])

            if torch.cuda.device_count():
                if "gpu_rng_state" in rng_state:
                    torch.cuda.set_rng_state(
                        rng_state["gpu_rng_state"],
                        device=self.context.distributed.get_local_rank())
                else:
                    logging.warning(
                        "The system has a gpu but no gpu_rng_state exists in the checkpoint."
                    )
            else:
                if "gpu_rng_state" in rng_state:
                    logging.warning(
                        "There exists gpu_rng_state in checkpoint but the system has no gpu."
                    )
        else:
            logging.warning("The checkpoint has no random state to restore.")

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    pytorch.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch.".format(name))
Example #17
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for callback in self.callbacks.values():
                callback.on_validation_epoch_start()
            for idx, batch in enumerate(self.validation_loader):
                batch = self.context.to_device(batch)
                num_inputs += self.trial.get_batch_length(batch)

                if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                    vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                            batch_idx=idx)
                else:
                    vld_metrics = self.trial.evaluate_batch(
                        batch=batch)  # type: ignore
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "validation_metrics() must return a "
                    "dictionary of string names to Tensor "
                    "metrics",
                )
                # TODO: for performance, move metrics to the CPU only at the end of validation.
                batch_metrics.append(
                    self._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            for callback in self.callbacks.values():
                callback.on_validation_epoch_end(batch_metrics)

            metrics = self._reduce_metrics(
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=self._prepare_metrics_reducers(keys=keys),
            )

            if self.hvd_config.use:
                num_inputs *= hvd.size()

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"eval() must return a dictionary, got {type(metrics)}.")

                metrics = self._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            self._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

        if self.hvd_config.use and any(
                map(
                    lambda c: util.is_overridden(
                        c.on_validation_end, pytorch.
                        PyTorchCallback) or util.is_overridden(
                            c.on_validation_step_end, pytorch.PyTorchCallback),
                    self.callbacks.values(),
                )):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example #18
    def _load(self, load_path: pathlib.Path) -> None:
        # Backwards compat with older checkpoint formats. List is newest to
        # oldest known state_dict locations.
        potential_paths = [
            ["state_dict.pth"],
            ["determined", "state_dict.pth"],
            ["pedl", "state_dict.pth"],
            ["checkpoint.pt"],
        ]

        checkpoint: Optional[Dict[str, Any]] = None
        for ckpt_path in potential_paths:
            maybe_ckpt = load_path.joinpath(*ckpt_path)
            if maybe_ckpt.exists():
                checkpoint = torch.load(str(maybe_ckpt),
                                        map_location="cpu")  # type: ignore
                break
        if checkpoint is None or not isinstance(checkpoint, dict):
            return

        for callback in self.callbacks.values():
            callback.on_checkpoint_load_start(checkpoint)

        if "model_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("models_state_dict", checkpoint)
            check.eq(len(self.context.models), 1)
            self.context.models[0].load_state_dict(
                checkpoint["model_state_dict"])
        else:
            for idx, model in enumerate(self.context.models):
                model_state_dict = checkpoint["models_state_dict"][idx]
                try:
                    model.load_state_dict(model_state_dict)
                except Exception:
                    # If the checkpointed model is non-DDP and the current model is DDP, append
                    # module prefix to the checkpointed data
                    if isinstance(model,
                                  torch.nn.parallel.DistributedDataParallel):
                        logging.debug(
                            "Loading non-DDP checkpoint into a DDP model")
                        self._add_prefix_in_state_dict_if_not_present(
                            model_state_dict, "module.")
                    else:
                        # If the checkpointed model is DDP and we are currently running in
                        # single-slot mode, remove the module prefix from checkpointed data
                        logging.debug(
                            "Loading DDP checkpoint into a non-DDP model")
                        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
                            model_state_dict, "module.")
                    model.load_state_dict(model_state_dict)

        if "optimizer_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("optimizers_state_dict", checkpoint)
            check.eq(len(self.context.optimizers), 1)
            self.context.optimizers[0].load_state_dict(
                checkpoint["optimizer_state_dict"])
        else:
            for idx, optimizer in enumerate(self.context.optimizers):
                optimizer.load_state_dict(
                    checkpoint["optimizers_state_dict"][idx])

        if "lr_scheduler" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("lr_schedulers_state_dict", checkpoint)
            check.eq(len(self.context.lr_schedulers), 1)
            self.context.lr_schedulers[0].load_state_dict(
                checkpoint["lr_scheduler"])
        else:
            for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
                lr_scheduler.load_state_dict(
                    checkpoint["lr_schedulers_state_dict"][idx])

        if "scaler_state_dict" in checkpoint:
            if self.context._scaler:
                self.context._scaler.load_state_dict(
                    checkpoint["scaler_state_dict"])
            else:
                logging.warning(
                    "There exists scaler_state_dict in checkpoint but the experiment is not using "
                    "AMP.")
        else:
            if self.context._scaler:
                logging.warning(
                    "The experiment is using AMP but scaler_state_dict does not exist in the "
                    "checkpoint.")

        if "amp_state" in checkpoint:
            if self.context._use_apex:
                apex.amp.load_state_dict(checkpoint["amp_state"])
            else:
                logging.warning(
                    "There exists amp_state in checkpoint but the experiment is not using Apex."
                )
        else:
            if self.context._use_apex:
                logging.warning(
                    "The experiment is using Apex but amp_state does not exist in the checkpoint."
                )

        if "rng_state" in checkpoint:
            rng_state = checkpoint["rng_state"]
            np.random.set_state(rng_state["np_rng_state"])
            random.setstate(rng_state["random_rng_state"])
            torch.random.set_rng_state(rng_state["cpu_rng_state"])

            if torch.cuda.device_count():
                if "gpu_rng_state" in rng_state:
                    torch.cuda.set_rng_state(
                        rng_state["gpu_rng_state"],
                        device=self.context.distributed.local_rank)
                else:
                    logging.warning(
                        "The system has a gpu but no gpu_rng_state exists in the checkpoint."
                    )
            else:
                if "gpu_rng_state" in rng_state:
                    logging.warning(
                        "There exists gpu_rng_state in checkpoint but the system has no gpu."
                    )
        else:
            logging.warning("The checkpoint has no random state to restore.")

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    pytorch.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch.".format(name))

        # Load workload sequencer state.
        wlsq_path = load_path.joinpath("workload_sequencer.pkl")
        if self.wlsq is not None and wlsq_path.exists():
            with wlsq_path.open("rb") as f:
                self.wlsq.load_state(pickle.load(f))