def _respond(in_response: workload.Response) -> None:
    """Relay TRAIN_FOR_STEP results from the harness up to the master.

    Only the chief (rank 0) responds with real metrics; other ranks send
    ``workload.Skipped()``. Handles the invalid-hyperparameter early exit
    and the user-requested-stop exit reason.
    """
    # Only the chief container should actually respond to TRAIN_FOR_STEP.
    if self.rendezvous_info.get_rank() != 0:
        respond(workload.Skipped())
        return

    check_not_isinstance(in_response, workload.Skipped, "Chief skipped a workload.")
    in_response = cast(workload.Metrics, in_response)
    metrics = cast(workload.Metrics, in_response["metrics"])

    def _completed_response() -> Dict[str, Any]:
        # Built at call time so end_time reflects any callback/sync work
        # that ran before the response is assembled.
        return {
            "type": "WORKLOAD_COMPLETED",
            "workload": wkld,
            "start_time": start_time,
            "end_time": _current_timestamp(),
            "metrics": metrics,
        }

    if in_response.get("invalid_hp", False):
        out_response = _completed_response()
        out_response["exited_reason"] = "INVALID_HP"
        respond(out_response)
        return

    batch_metrics = metrics["batch_metrics"]
    # Sanity-check training metrics.
    det.util.validate_batch_metrics(batch_metrics)
    check_len(batch_metrics, wkld.num_batches)

    for callback in self.callbacks:
        callback.on_train_step_end(
            wkld.step_id, wkld.num_batches, wkld.total_batches_processed, metrics
        )

    self.tensorboard_mgr.sync()

    out_response = _completed_response()
    if in_response.get("stop_requested", False):
        out_response["exited_reason"] = "USER_CANCELED"

    # Send the response up.
    respond(out_response)
def _respond(checkpoint_info: workload.Response) -> None: checkpoint_info = cast(Dict[str, Any], checkpoint_info) metadata = storage.StorageMetadata( storage_id, storage.StorageManager._list_directory(path), checkpoint_info.get("framework", ""), checkpoint_info.get("format", ""), ) logging.info("Saved trial to checkpoint {}".format(metadata.storage_id)) self.tensorboard_mgr.sync() nonlocal message message = { "type": "WORKLOAD_COMPLETED", "workload": wkld, "start_time": start_time, "end_time": _current_timestamp(), "metrics": metadata, }
def _respond(in_response: workload.Response) -> None:
    """Relay COMPUTE_VALIDATION_METRICS results up to the master.

    Only the chief (rank 0) responds with real metrics. Validates that the
    searcher's target metric is present and that metric values are sane
    (no None/NaN), drops non-JSON-serializable metric values, and handles
    the user-requested-stop exit reason.

    Raises:
        AssertionError: if the searcher metric is missing from the returned
            validation metrics, or if any metric value is None or NaN.
    """
    # Only the chief container should actually respond to COMPUTE_VALIDATION_METRICS.
    if self.rendezvous_info.get_rank() != 0:
        respond(workload.Skipped())
        return

    check_not_isinstance(in_response, workload.Skipped, "Chief skipped a workload.")
    in_response = cast(Dict[str, Any], in_response)
    metrics = cast(workload.Metrics, in_response["metrics"])
    v_metrics = metrics["validation_metrics"]

    for callback in self.callbacks:
        callback.on_validation_step_end(wkld.step_id, wkld.total_batches_processed, v_metrics)

    self.tensorboard_mgr.sync()

    # Check that the validation metrics computed by the model code
    # includes the metric used by the search method.
    searcher_metric = self.env.experiment_config["searcher"]["metric"]
    if searcher_metric not in v_metrics:
        # Fix: the original followed this raise with an unreachable
        # sys.exit(1); the raise always transfers control first.
        raise AssertionError(
            "Search method is configured to use metric '{}' but model "
            "definition returned validation metrics {}. The metric "
            "used by the search method must be one of the validation "
            "metrics returned by the model definition.".format(
                searcher_metric, list(v_metrics.keys())))

    non_serializable_metrics = set()
    # NaN and bytes are not JSON serializable. None does not have a
    # canonical JSON representation. In the case of trial implementation bugs
    # or numerical instability issues, validation metric functions may
    # return None or NaN values. For now, immediately fail any trial that
    # encounters such a None metric. For NaN metrics, if it's the target of
    # the searcher, we set it to +/- max_float depending on if the searcher
    # is optimizing for the max or min. NaN metrics which are not the
    # target of the searcher are dropped.
    # TODO (DET-2495): Do not replace NaN metric values.
    for metric_name, metric_value in v_metrics.items():
        metric_is_none = metric_value is None
        metric_is_nan = tensorboard.metric_writers.util.is_numerical_scalar(
            metric_value) and math.isnan(metric_value)
        if metric_is_none or metric_is_nan:
            # Fix: removed the unreachable sys.exit(1) that followed this
            # raise in the original.
            raise AssertionError("Validation metric '{}' returned "
                                 "an invalid scalar value: {}".format(
                                     metric_name, metric_value))
        if isinstance(metric_value, (bytes, bytearray)):
            non_serializable_metrics.add(metric_name)

    if len(non_serializable_metrics):
        logging.warning("Removed non serializable metrics: %s",
                        ", ".join(non_serializable_metrics))
    for metric_name in non_serializable_metrics:
        del v_metrics[metric_name]

    out_response = {
        "type": "WORKLOAD_COMPLETED",
        "workload": wkld,
        "start_time": start_time,
        "end_time": _current_timestamp(),
        "metrics": metrics,
    }

    if in_response.get("stop_requested", False):
        out_response["exited_reason"] = "USER_CANCELED"

    respond(out_response)