Example #1
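        # The nested `_respond` handlers below rely on names from their enclosing
        # scope (`self`, `respond`, `wkld`, `start_time`, `storage_id`, `path`,
        # `message`) and on module-level imports such as `cast`, `Dict`, `Any`,
        # `workload`, `storage`, `tensorboard`, `det`, `logging`, and `math`.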
        def _respond(in_response: workload.Response) -> None:
            # Only the chief container should actually respond to TRAIN_FOR_STEP.
            if self.rendezvous_info.get_rank() != 0:
                respond(workload.Skipped())
                return

            check_not_isinstance(in_response, workload.Skipped,
                                 "Chief skipped a workload.")

            in_response = cast(Dict[str, Any], in_response)
            metrics = in_response["metrics"]
            metrics = cast(workload.Metrics, metrics)

            if in_response.get("invalid_hp", False):
                out_response = {
                    "type": "WORKLOAD_COMPLETED",
                    "workload": wkld,
                    "start_time": start_time,
                    "end_time": _current_timestamp(),
                    "metrics": metrics,
                }
                out_response["exited_reason"] = "INVALID_HP"
                respond(out_response)
                return

            batch_metrics = metrics["batch_metrics"]

            # Sanity-check training metrics.
            det.util.validate_batch_metrics(batch_metrics)
            check_len(batch_metrics, wkld.num_batches)
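            # For illustration (assumed shape): batch_metrics is expected to be a
            # list with one metrics dict per batch, e.g.
            # [{"loss": 0.42}, {"loss": 0.40}, ...], so check_len verifies that
            # its length equals wkld.num_batches.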

            for callback in self.callbacks:
                callback.on_train_step_end(wkld.step_id, wkld.num_batches,
                                           wkld.total_batches_processed,
                                           metrics)

            self.tensorboard_mgr.sync()

            out_response = {
                "type": "WORKLOAD_COMPLETED",
                "workload": wkld,
                "start_time": start_time,
                "end_time": _current_timestamp(),
                "metrics": metrics,
            }

            if in_response.get("stop_requested", False):
                out_response["exited_reason"] = "USER_CANCELED"

            # Send the response up.
            respond(out_response)
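
        # Illustration only: hypothetical stand-ins sketching what the check
        # helpers used above are assumed to do (not Determined's implementations).
        def _check_not_isinstance_sketch(value, cls, msg):
            # Fail fast if `value` is an instance of the disallowed class.
            if isinstance(value, cls):
                raise AssertionError(msg)

        def _check_len_sketch(seq, expected):
            # Fail fast if `seq` does not contain exactly `expected` items.
            if len(seq) != expected:
                raise AssertionError(
                    "expected length {}, got {}".format(expected, len(seq)))
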
        def _respond(checkpoint_info: workload.Response) -> None:
            checkpoint_info = cast(Dict[str, Any], checkpoint_info)
            metadata = storage.StorageMetadata(
                storage_id,
                storage.StorageManager._list_directory(path),
                checkpoint_info.get("framework", ""),
                checkpoint_info.get("format", ""),
            )

            logging.info("Saved trial to checkpoint {}".format(metadata.storage_id))
            self.tensorboard_mgr.sync()

            nonlocal message
            message = {
                "type": "WORKLOAD_COMPLETED",
                "workload": wkld,
                "start_time": start_time,
                "end_time": _current_timestamp(),
                "metrics": metadata,
            }
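
        # Illustration only: a hypothetical helper sketching the kind of mapping
        # StorageManager._list_directory(path) is assumed to produce (relative
        # file path -> size in bytes); not Determined's implementation.
        def _list_directory_sketch(root):
            import os

            resources = {}
            for parent, _, files in os.walk(root):
                for name in files:
                    abs_path = os.path.join(parent, name)
                    # Key each file by its path relative to the checkpoint root.
                    resources[os.path.relpath(abs_path, root)] = os.path.getsize(abs_path)
            return resources
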
        def _respond(in_response: workload.Response) -> None:

            # Only the chief container should actually respond to COMPUTE_VALIDATION_METRICS.
            if self.rendezvous_info.get_rank() != 0:
                respond(workload.Skipped())
                return

            check_not_isinstance(in_response, workload.Skipped,
                                 "Chief skipped a workload.")
            in_response = cast(Dict[str, Any], in_response)
            metrics = in_response["metrics"]
            metrics = cast(workload.Metrics, metrics)

            v_metrics = metrics["validation_metrics"]
            for callback in self.callbacks:
                callback.on_validation_step_end(wkld.step_id,
                                                wkld.total_batches_processed,
                                                v_metrics)

            self.tensorboard_mgr.sync()

            # Check that the validation metrics computed by the model code
            # include the metric used by the search method.
            searcher_metric = self.env.experiment_config["searcher"]["metric"]
            if searcher_metric not in v_metrics:
                raise AssertionError(
                    "Search method is configured to use metric '{}' but model "
                    "definition returned validation metrics {}. The metric "
                    "used by the search method must be one of the validation "
                    "metrics returned by the model definition.".format(
                        searcher_metric, list(v_metrics.keys())))

            non_serializable_metrics = set()
            # NaN and bytes are not JSON serializable. None does not have a
            # canonical JSON representation. In the case of trial implementation
            # bugs or numerical instability issues, validation metric functions
            # may return None or NaN values. For now, immediately fail any trial
            # that encounters a None or NaN metric value; non-serializable bytes
            # values are dropped below with a warning.
            # TODO (DET-2495): Do not replace NaN metric values.
            for metric_name, metric_value in v_metrics.items():
                metric_is_none = metric_value is None
                metric_is_nan = tensorboard.metric_writers.util.is_numerical_scalar(
                    metric_value) and math.isnan(metric_value)

                if metric_is_none or metric_is_nan:
                    raise AssertionError("Validation metric '{}' returned "
                                         "an invalid scalar value: {}".format(
                                             metric_name, metric_value))

                if isinstance(metric_value, (bytes, bytearray)):
                    non_serializable_metrics.add(metric_name)

            if non_serializable_metrics:
                logging.warning("Removed non-serializable metrics: %s",
                                ", ".join(non_serializable_metrics))
                for metric_name in non_serializable_metrics:
                    del v_metrics[metric_name]

            out_response = {
                "type": "WORKLOAD_COMPLETED",
                "workload": wkld,
                "start_time": start_time,
                "end_time": _current_timestamp(),
                "metrics": metrics,
            }

            if in_response.get("stop_requested", False):
                out_response["exited_reason"] = "USER_CANCELED"

            respond(out_response)
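
        # Illustration only: a self-contained sketch of the validation-metric
        # sanitization performed above, written as a reusable helper. The name
        # and the example values are hypothetical.
        def _sanitize_validation_metrics_sketch(v_metrics):
            import logging
            import math

            cleaned = {}
            for name, value in v_metrics.items():
                is_nan = isinstance(value, float) and math.isnan(value)
                if value is None or is_nan:
                    # Mirror the behavior above: fail the trial on invalid scalars.
                    raise AssertionError(
                        "Validation metric '{}' returned an invalid scalar "
                        "value: {}".format(name, value))
                if isinstance(value, (bytes, bytearray)):
                    # bytes are not JSON serializable; drop them with a warning.
                    logging.warning("Removed non-serializable metric: %s", name)
                    continue
                cleaned[name] = value
            return cleaned

        # Example usage (hypothetical values):
        #   _sanitize_validation_metrics_sketch({"loss": 0.12, "raw": b"\x00"})
        #   returns {"loss": 0.12}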