def new_train_batch(batch: pytorch.TorchData, epoch_idx: int, batch_idx: int) -> Any:
    tr_metrics = train_batch(
        batch=batch,
        model=model,
        epoch_idx=epoch_idx,
        batch_idx=batch_idx,
    )
    if isinstance(tr_metrics, torch.Tensor):
        tr_metrics = {"loss": tr_metrics}
    check.is_instance(
        tr_metrics,
        dict,
        "train_batch() must return a dictionary "
        f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
    )
    check.is_in("loss", tr_metrics.keys(), 'Please include "loss" in your training metrics.')

    def clip_grads(parameters: Iterator) -> None:
        for callback in self.callbacks.values():
            callback.on_before_optimizer_step(parameters)

    self.context.backward(tr_metrics["loss"])
    self.context.step_optimizer(self.context.optimizers[0], clip_grads=clip_grads)

    return tr_metrics
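
# A hypothetical user-defined train_batch() illustrating the contract enforced above:
# it may return either a bare loss Tensor or a dict of named Tensor metrics that
# includes a "loss" entry. (Sketch only; `model` and the batch layout are assumptions,
# not part of the library code.)
def example_train_batch(batch, model, epoch_idx, batch_idx):
    data, labels = batch
    output = model(data)
    loss = torch.nn.functional.cross_entropy(output, labels)
    accuracy = (output.argmax(dim=1) == labels).float().mean()
    # Equivalent minimal form: `return loss` (it would be wrapped as {"loss": loss}).
    return {"loss": loss, "accuracy": accuracy}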
def metrics_result(self) -> Metrics:
    """Identical to result but disallow workload.Skipped responses."""
    check.is_not_none(self._response, "_respond() was not called by the TrialController.")
    check.is_instance(self._response, dict, "unexpected SkippedWorkload response.")
    return cast(Metrics, self._response)
def run(self) -> None:
    for w, args, response_func in self.workloads:
        if w.kind == workload.Workload.Kind.RUN_STEP:
            response_func(
                util.wrap_metrics(
                    self._train_for_step(w.step_id, w.num_batches, w.total_batches_processed),
                    self.context.get_stop_requested(),
                )
            )
        elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                util.wrap_metrics(
                    self._compute_validation_metrics(), self.context.get_stop_requested()
                )
            )
        elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.eq(len(args), 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save(path))
        elif w.kind == workload.Workload.Kind.TERMINATE:
            response_func({} if self.is_chief else workload.Skipped())
            break
        else:
            raise AssertionError("Unexpected workload: {}".format(w.kind))
def _calculate_batch_sizes(self) -> Tuple[int, int]:
    if "global_batch_size" not in self.hparams.keys():
        raise AssertionError(
            "Please specify `global_batch_size` under `hyperparameters` "
            "in the experiment config."
        )
    if "batch_size" in self.hparams.keys():
        logging.warning(
            "Use `global_batch_size` instead of `batch_size` under `hyperparameters` "
            "in the experiment config."
        )

    global_batch_size = self.hparams["global_batch_size"]
    check.is_instance(global_batch_size, int, "`global_batch_size` hparam must be an int.")
    global_batch_size = cast(int, global_batch_size)

    if self.experiment_config.native_parallel_enabled():
        return global_batch_size, global_batch_size

    # Configure batch sizes.
    slots_per_trial = self.experiment_config.slots_per_trial()
    if global_batch_size < slots_per_trial:
        raise AssertionError(
            "Please set the `global_batch_size` hyperparameter to be greater or equal to the "
            f"number of slots. Current batch_size: {global_batch_size}, slots_per_trial: "
            f"{slots_per_trial}."
        )

    per_gpu_batch_size = global_batch_size // slots_per_trial
    effective_batch_size = per_gpu_batch_size * slots_per_trial
    if effective_batch_size != global_batch_size:
        logging.warning(
            f"`global_batch_size` changed from {global_batch_size} to {effective_batch_size} "
            f"to divide equally across {slots_per_trial} slots."
        )

    return per_gpu_batch_size, effective_batch_size
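
# A minimal, standalone sketch (not part of the controller above) of the same batch-size
# arithmetic with hypothetical values: the global batch size is floor-divided across slots,
# and the effective batch size is what is actually trained on when the division is uneven.
def _example_split_global_batch_size(global_batch_size: int, slots_per_trial: int) -> Tuple[int, int]:
    per_gpu_batch_size = global_batch_size // slots_per_trial
    effective_batch_size = per_gpu_batch_size * slots_per_trial
    return per_gpu_batch_size, effective_batch_size


# 32 samples over 8 slots divides evenly: 4 per GPU, effective size stays 32.
assert _example_split_global_batch_size(32, 8) == (4, 32)
# 30 samples over 8 slots does not: 3 per GPU, effective size shrinks to 24,
# which is exactly when the warning above fires.
assert _example_split_global_batch_size(30, 8) == (3, 24)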
def run(self) -> None:
    """
    A basic control loop of the old-style (callback-based) TrialController classes.
    """
    for w, args, response_func in self.workloads:
        try:
            if w.kind == workload.Workload.Kind.RUN_STEP:
                response = self.train_for_step(
                    w.step_id, w.num_batches
                )  # type: workload.Response
            elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                response = self.compute_validation_metrics(w.step_id)
            elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                self.save(path)
                response = {}
            elif w.kind == workload.Workload.Kind.TERMINATE:
                self.terminate()
                response = workload.Skipped()
            else:
                raise AssertionError("Unexpected workload: {}".format(w.kind))
        except det.errors.SkipWorkloadException:
            response = workload.Skipped()

        response_func(response)
def barrier(
    self, num_connections: int, message: Any = None, timeout: Optional[int] = None
) -> List[Any]:
    """
    This is a one-sided barrier, where the chief blocks until all non-chief trial
    containers have sent a message.
    """
    check.eq(len(self.sockets), 1)

    messages = []  # type: List[Any]
    start_time = time.time()

    for _ in range(num_connections):
        if timeout:
            message_received, barrier_message = self.receive_non_blocking(
                send_rank=0, deadline=start_time + timeout
            )
            if not message_received:
                return messages
        else:
            barrier_message = self.receive_blocking(0)

        check.is_instance(barrier_message, _OneSidedBarrier)
        messages.append(barrier_message.message)
        self.sockets[0].send_pyobj(_OneSidedBarrier(message=message))
    return messages
def __init__(
    self,
    context: Union[keras.TFKerasTrialContext, keras.TFKerasNativeContext],
    train_config: keras.TFKerasTrainConfig,
) -> None:
    super().__init__(context=context)

    self._training_cacheable = self._context.experimental.get_train_cacheable()
    self._training_dataset = train_config.training_data

    check.true(
        self._training_cacheable.is_decorator_used(),
        "Please use `@context.experimental.cache_train_dataset(dataset_name, dataset_version)`"
        " for the training dataset.",
    )
    check.false(
        self._context.dataset_initialized,
        "Please do not use `context.wrap_dataset(dataset)` if using "
        "`@context.experimental.cache_train_dataset()` and "
        "`@context.experimental.cache_validation_dataset()`.",
    )
    check.is_instance(
        train_config.training_data,
        tf.data.Dataset,
        "Pass in a `tf.data.Dataset` object if using "
        "`@context.experimental.cache_train_dataset()`.",
    )
def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)

    check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
    self.trial = cast(PyTorchTrial, trial_inst)

    self._check_evaluate_implementation()
    self._init_model_and_optimizer()

    # Validation loader will be undefined on process ranks > 0
    # when the user defines `validate_full_dataset()`.
    self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
    self._set_data_loaders()

    # Track whether a warning logging category has already been issued to the user.
    self.warning_logged = {_WarningLogs.FAILED_MOVING_TO_DEVICE: False}

    self.context.lr_scheduler = self.trial.create_lr_scheduler(self.context.optimizer)

    self.callbacks = self.trial.build_callbacks()

    # If a load path is provided, load weights and restore the data location.
    self._load()

    self._configure_amp()

    if self.hvd_config.use:
        hvd.broadcast_parameters(self.context.model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(self.context.optimizer, root_rank=0)

    self.training_iterator = iter(self.training_loader)
def run(self) -> None:
    for w, args, response_func in self.workloads:
        if w.kind == workload.Workload.Kind.RUN_STEP:
            metrics = det.util.make_metrics(
                num_inputs=None,
                batch_metrics=[{"loss": 1} for _ in range(w.num_batches)],
            )
            response_func({"metrics": metrics})
        elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            check.len_eq(args, 0)
            response_func({"metrics": {"validation_metrics": self.validation_metrics}})
        elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            if not path.exists():
                path.mkdir(parents=True, exist_ok=True)
            with path.joinpath("a_file").open("w") as f:
                f.write("yup")
            response_func({})
        elif w.kind == workload.Workload.Kind.TERMINATE:
            raise NotImplementedError()
def control_loop(self) -> None:
    for wkld, args, response_func in self.estimator_trial_controller.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Store values for the training loop.
            self.num_batches = wkld.num_batches
            self.train_response_func = response_func
            # Break out of the control loop so that the train process
            # re-enters the train_and_evaluate() loop.
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(
                    self._compute_validation_metrics(),
                    self.estimator_trial_controller.context.get_stop_requested(),
                )
            )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._checkpoint_model(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            self.estimator_trial_controller.exit_response_func = response_func
            raise det.errors.WorkerFinishedGracefully("Exiting normally.")
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
    else:
        # The workload stream was exhausted without a RUN_STEP or TERMINATE; exit.
        exit(1)
def from_trial(
    trial_inst: det.Trial,
    context: det.TrialContext,
    env: det.EnvContext,
    *args: Any,
    **kwargs: Any,
) -> det.TrialController:
    check.is_instance(
        context,
        estimator.EstimatorTrialContext,
        "EstimatorTrialController needs an EstimatorTrialContext",
    )
    context = cast(estimator.EstimatorTrialContext, context)

    check.is_instance(
        trial_inst, EstimatorTrial, "EstimatorTrialController needs an EstimatorTrial"
    )
    trial_inst = cast(EstimatorTrial, trial_inst)

    return EstimatorTrialController(
        trial_inst.build_estimator(),
        trial_inst.build_train_spec(),
        trial_inst.build_validation_spec(),
        trial_inst.build_serving_input_receiver_fns(),
        context,
        env,
        *args,
        **kwargs,
    )
def _control_loop(self) -> None:
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Configure the state for a training step.
            self.train_response_func = response_func
            self.train_workload_batches = 0
            self.train_workload_metrics = []
            self.train_workload_len = wkld.num_batches
            self.multiplexer.set_batches_requested(wkld.num_batches)
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(
                    self._compute_validation_metrics(), self.context.get_stop_requested()
                )
            )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            response_func({} if self.is_chief else workload.Skipped())
            self.multiplexer._corrected_train_end()
            raise det.errors.WorkerFinishedGracefully
        else:
            raise AssertionError(f"Unknown workload kind {wkld.kind}.")
def _launch_evaluate(self) -> Any:
    validation_data = self.validation_data
    steps = None

    # Support the deprecated SequenceAdapter API.
    if isinstance(validation_data, keras.SequenceAdapter):
        # Ignore these settings and use the same settings as for the fit call.
        validation_data = validation_data.sequence

    if isinstance(validation_data, tf.keras.utils.Sequence):
        # Calculate the length of our validation shard.
        steps = len(validation_data)
        if self.context.distributed.get_size() > 1:
            size = self.context.distributed.get_size()
            rank = self.context.distributed.get_rank()
            steps = steps // size + (1 if steps % size > rank else 0)

        # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
        enqueuer = keras._build_enqueuer(
            sequence=validation_data,
            workers=self.context._fit_workers,
            use_multiprocessing=self.context._fit_use_multiprocessing,
            max_queue_size=self.context._fit_max_queue_size,
            shard_rank=self.context.distributed.get_rank(),
            num_shards=self.context.distributed.get_size(),
            repeat=False,
            shuffle=False,
            shuffle_seed=0,
            prior_batches_trained=0,
        )
        enqueuer.start()
        self.enqueuers.append(enqueuer)
        validation_data = enqueuer.data()

    # Starting in TF 2.2, users may define a custom test_step() that does
    # not use the model metrics.
    use_model_metrics = version.parse(tf.__version__) < version.parse("2.2.0")
    evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

    metrics_values = self.model.evaluate(
        validation_data,
        callbacks=self.callback_list,
        steps=steps,
        verbose=0,
        workers=0,
        **evaluate_kwargs,
    )
    logging.debug(f"Worker finished model.evaluate() with metrics: {metrics_values}.")

    # If the model was compiled with metrics=None, metrics_values will be a single value.
    if not isinstance(metrics_values, (tuple, list, dict)):
        metrics_values = (metrics_values,)

    if use_model_metrics:
        metrics = make_logs(self.model, {}, metrics_values, ModeKeys.TEST, prefix="val_")
    else:
        check.is_instance(metrics_values, dict)
        metrics = {f"val_{k}": v for k, v in metrics_values.items()}

    return metrics
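
# The sharding arithmetic above gives earlier ranks one extra batch when the validation
# length does not divide evenly. A worked example with assumed values (10 batches, 4 ranks):
# 10 // 4 == 2 and 10 % 4 == 2, so ranks 0 and 1 get 3 batches, ranks 2 and 3 get 2,
# and the shards sum back to 10.
assert [10 // 4 + (1 if 10 % 4 > rank else 0) for rank in range(4)] == [3, 3, 2, 2]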
def from_native(
    context: det.NativeContext,
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    check.is_instance(
        context,
        keras.TFKerasNativeContext,
        "TFKerasTrialController needs a TFKerasSprinkleContext",
    )
    context = cast(keras.TFKerasNativeContext, context)

    check.is_not_none(context.model, "Please call wrap_model(...).")

    check.is_not_none(context.compile_args, "Please call model.compile(...).")
    check.is_not_none(
        context.train_config,
        "Please call model.fit(...) or model.fit_generator(...).",
    )

    # For the Native API, we would break the user's model if we changed the session
    # right now, so we have to trust the user did not modify what we set previously.
    #
    # TODO(ryan): Fix this, probably with a function for configuring the backend session.
    session = tf.compat.v1.keras.backend.get_session()

    compile_args = cast(inspect.BoundArguments, context.compile_args)
    train_config = cast(keras.TFKerasTrainConfig, context.train_config)

    (
        context.model,
        compile_args.arguments["optimizer"],
    ) = keras._get_multi_gpu_model_and_optimizer(
        pre_compiled_model=context.model,
        optimizer=compile_args.arguments["optimizer"],
        env=env,
        hvd_config=hvd_config,
        profile_frequency=env.experiment_config.profile_frequency(),
        profile_filename=DeterminedProfiler.OUTPUT_FILENAME,
    )

    context.model.compile(*compile_args.args, **compile_args.kwargs)

    return TFKerasTrialController(
        context.model,
        session,
        train_config,
        context,
        env,
        workloads,
        load_path,
        rendezvous_info,
        hvd_config,
    )
def from_trial(
    trial_inst: det.Trial,
    context: det.TrialContext,
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    check.is_instance(
        context, keras.TFKerasTrialContext, "TFKerasTrialController needs a TFKerasTrialContext"
    )
    context = cast(keras.TFKerasTrialContext, context)

    check.is_instance(trial_inst, TFKerasTrial, "TFKerasTrialController needs a TFKerasTrial")
    trial = cast(TFKerasTrial, trial_inst)

    session = TFKerasTrialController._configure_session(env, hvd_config, trial.session_config())

    training_data = keras._adapt_data_from_data_loader(
        input_data=trial.build_training_data_loader(),
        batch_size=context.get_per_slot_batch_size(),
    )

    validation_data = keras._adapt_data_from_data_loader(
        input_data=trial.build_validation_data_loader(),
        batch_size=context.get_per_slot_batch_size(),
    )

    trial.build_model()
    check.is_not_none(context.model, "Please call wrap_model(...).")

    check.is_not_none(context.compile_args, "Please call model.compile(...).")
    compile_args = cast(inspect.BoundArguments, context.compile_args)

    TFKerasTrialController.compile_model(
        context=context, compile_args=compile_args, env=env, hvd_config=hvd_config
    )

    tf_keras_callbacks = trial.keras_callbacks()

    return TFKerasTrialController(
        context.model,
        session,
        keras.TFKerasTrainConfig(training_data, validation_data, tf_keras_callbacks),
        context,
        env,
        workloads,
        load_path,
        rendezvous_info,
        hvd_config,
    )
def barrier(self, message: Any = None) -> Any:
    """
    This is a one-sided barrier, where the chief blocks until all non-chief trial
    containers have sent a message.
    """
    self.socket.send_pyobj(_OneSidedBarrier(message=message))
    barrier_message = self.socket.recv_pyobj()
    check.is_instance(barrier_message, _OneSidedBarrier)
    return barrier_message.message
def _do_startup_message_sequence(self) -> None:
    # Wait for a ReadyMessage from every worker.
    responses = self.broadcast_server.gather_with_polling(self._health_check)
    for response in responses:
        check.is_instance(
            response,
            ipc.ReadyMessage,
            f"Did not receive ReadyMessage from worker. Got: {response}",
        )
def _do_startup_message_sequence(self) -> None:
    # Wait for a ReadyMessage from every worker.
    responses, exception_received = self.broadcast_server.gather_with_polling(
        self._health_check
    )

    if exception_received:
        raise det.errors.WorkerError("Training process died.")

    for response in responses:
        check.is_instance(
            response,
            ipc.ReadyMessage,
            f"Did not receive ReadyMessage from worker. Got: {response}",
        )
def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)

    check.is_instance(
        trial_inst, TensorpackTrial, "TensorpackTrialController needs a TensorpackTrial"
    )
    self.trial = cast(TensorpackTrial, trial_inst)

    training_dataflow = self.trial.build_training_dataflow()
    validation_dataflow = self.trial.build_validation_dataflow()

    # Set if the model is initialized from scratch.
    self.session_init = None  # type: Optional[Any]

    self._init_model(training_dataflow, validation_dataflow)
def from_native(
    context: det.NativeContext,
    env: det.EnvContext,
    workloads: workload.Stream,
    load_path: Optional[pathlib.Path],
    rendezvous_info: det.RendezvousInfo,
    hvd_config: horovod.HorovodContext,
) -> det.TrialController:
    check.is_instance(
        context,
        keras.TFKerasNativeContext,
        "TFKerasTrialController needs a TFKerasSprinkleContext",
    )
    context = cast(keras.TFKerasNativeContext, context)

    check.is_not_none(context.model, "Please call wrap_model(...).")

    check.is_not_none(context.compile_args, "Please call model.compile(...).")
    check.is_not_none(
        context.train_config, "Please call model.fit(...) or model.fit_generator(...)."
    )

    # For the Native API, we would break the user's model if we changed the session
    # right now, so we have to trust the user did not modify what we set previously.
    #
    # TODO(ryan): Fix this, probably with a function for configuring the backend session.
    session = tf.compat.v1.keras.backend.get_session()

    compile_args = cast(inspect.BoundArguments, context.compile_args)
    train_config = cast(keras.TFKerasTrainConfig, context.train_config)

    TFKerasTrialController.compile_model(
        context=context, compile_args=compile_args, env=env, hvd_config=hvd_config
    )

    return TFKerasTrialController(
        context.model,
        session,
        train_config,
        context,
        env,
        workloads,
        load_path,
        rendezvous_info,
        hvd_config,
    )
def run(self) -> None:
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")

        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Store the train_response_func for later.
            self.train_response_func = response_func
            self.num_batches = wkld.num_batches

            # There are two possibilities when a RUN_STEP workload is received.
            # 1) This is the first training step seen by the trial
            #    container. In this case, enter the tf.keras fit() training loop.
            # 2) This is _not_ the first training step, meaning that the
            #    tf.keras fit() training loop is already active and paused.
            #    break to re-enter the training loop.
            if not self.fit_loop_started:
                try:
                    self._launch_fit()
                except det.errors.WorkerFinishedGracefully:
                    pass

                if not self.expect_terminate:
                    raise AssertionError(
                        "Training loop exited unexpectedly but without throwing any errors. "
                        "This is possibly due to a user callback causing the training loop to "
                        "exit, which is not supported at this time."
                    )
            break

        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(
                    self.compute_validation_metrics(), self.context.get_stop_requested()
                )
            )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            self.model.stop_training = True
            self.expect_terminate = True
            response_func({} if self.is_chief else workload.Skipped())
            break
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
def _control_loop(self) -> None:
    for wkld, args, response_func in self.workloads:
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Move on to the next step.
            self.train_response_func = response_func
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(self._compute_validation_metrics())
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self.save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            raise det.errors.WorkerFinishedGracefully("Exiting normally.")
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}")
def run(self) -> None:
    for w, args, response_func in self.workloads:
        if w.kind == workload.Workload.Kind.RUN_STEP:
            check.eq(len(args), 1)
            num_batches = cast(int, args[0])
            response_func(self._train_for_step(w.step_id, num_batches))
        elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(self._compute_validation_metrics())
        elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.eq(len(args), 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save(path))
        elif w.kind == workload.Workload.Kind.TERMINATE:
            break
        else:
            raise AssertionError("Unexpected workload: {}".format(w.kind))
def _send_recv_workload(self, wkld: workload.Workload, args: List[Any]) -> workload.Response:
    # Broadcast every workload to every worker on this machine.
    self.broadcast_server.broadcast((wkld, args))

    if wkld.kind == workload.Workload.Kind.TERMINATE:
        # Do not perform health checks once workers have been instructed to terminate.
        self._worker_process_ids = []

    try:
        responses, exception_received = self.broadcast_server.gather_with_polling(
            self._health_check
        )
    except det.errors.WorkerError:
        if wkld.kind == workload.Workload.Kind.TERMINATE:
            return {}
        raise

    if exception_received:
        raise det.errors.WorkerError("Training process died.")

    # Find the response from the chief worker for the trial (the only non-SkippedWorkload). The
    # chief may report to another container, in which case we will only have SkippedWorkloads.
    chief_worker_response = None  # Optional[workload.Metrics]
    for response in responses:
        if isinstance(response, workload.Skipped):
            continue
        # Any other response must be a Dict[str, Any]-like object.
        check.is_instance(
            response, dict, f"Received non-metrics object from worker: {response}"
        )
        # There should only be one chief response.
        check.is_none(chief_worker_response, "Received multiple non-SkippedWorkload messages.")
        chief_worker_response = cast(Dict[str, Any], response)

    # Confirm that if we did not see a chief response, then we are not the chief machine.
    if chief_worker_response is None:
        check.gt(
            self.rendezvous_info.get_rank(),
            0,
            "Received SkippedWorkload message from chief worker.",
        )

    return workload.Skipped() if chief_worker_response is None else chief_worker_response
def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)

    check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
    self.trial = cast(PyTorchTrial, trial_inst)
    self.context = cast(pytorch.PyTorchTrialContext, self.context)
    self.context.experimental._set_allgather_fn(self.allgather_metrics)
    self.callbacks = self.trial.build_callbacks()

    self._apply_backwards_compatibility()

    check.gt_eq(
        len(self.context.models),
        1,
        "Must have at least one model. "
        "This might be caused by not wrapping your model with wrap_model()",
    )
    check.gt_eq(
        len(self.context.optimizers),
        1,
        "Must have at least one optimizer. "
        "This might be caused by not wrapping your optimizer with wrap_optimizer()",
    )
    self._check_evaluate_implementation()

    # Validation loader will be undefined on process ranks > 0
    # when the user defines `validate_full_dataset()`.
    self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
    self._set_data_loaders()

    # We don't want the training_iterator shuffling values after we load state.
    self.training_iterator = iter(self.training_loader)

    # If a load path is provided, load weights and restore the data location.
    self._load()

    if self.hvd_config.use:
        hvd.broadcast_parameters(self.context._main_model.state_dict(), root_rank=0)
        for optimizer in self.context.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
def _launch_evaluate(self) -> Any:
    (
        validation_data,
        validation_steps,
    ) = self._validation_input_manager.get_validation_input_and_num_batches()

    # Starting in TF 2.2, users may define a custom test_step() that does
    # not use the model metrics.
    use_model_metrics = version.parse(tf.__version__) < version.parse("2.2.0")
    evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

    metrics_values = self.model.evaluate(
        validation_data,
        steps=validation_steps,
        verbose=0,
        callbacks=self.callback_list,
        **evaluate_kwargs,
    )
    logging.debug(f"Worker finished model.evaluate() with metrics: {metrics_values}.")

    # If the model was compiled with metrics=None, metrics_values will be a single value.
    if not isinstance(metrics_values, (tuple, list, dict)):
        metrics_values = (metrics_values,)

    if use_model_metrics:
        metrics = make_logs(self.model, {}, metrics_values, ModeKeys.TEST, prefix="val_")
    else:
        check.is_instance(metrics_values, dict)
        metrics = {f"val_{k}": v for k, v in metrics_values.items()}

    _ = self._validation_input_manager.stop_validation_input_and_get_num_inputs()

    return metrics
def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)

    check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
    self.trial = cast(PyTorchTrial, trial_inst)
    self._check_evaluate_implementation()

    self.model = self.trial.build_model()

    # Validation loader will be undefined on process ranks > 0
    # when the user defines `validate_full_dataset()`.
    self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
    self._set_data_loaders()

    # Track whether a warning logging category has already been issued to the user.
    self.warning_logged = {_WarningLogs.FAILED_MOVING_TO_DEVICE: False}

    self._init_model()
def _init_model(self) -> None:
    self._init_paths()

    self.estimator = tf.estimator.Estimator(
        model_fn=self.estimator._model_fn,
        config=self._init_run_config(self.estimator.config),
        params=self.estimator.params if self.estimator.params != {} else None,
        warm_start_from=self.estimator._warm_start_settings,
    )

    check.is_instance(
        self.estimator,
        tf.estimator.Estimator,
        "Please modify your model definition's build_estimator() implementation to return "
        "an instance of `tf.estimator.Estimator`.",
    )
    check.is_instance(
        self.user_train_spec,
        tf.estimator.TrainSpec,
        "Please modify your model definition's build_train_spec() implementation to return "
        "an instance of `tf.estimator.TrainSpec`.",
    )
    check.is_instance(
        self.val_spec,
        tf.estimator.EvalSpec,
        "Please modify your model definition's build_validation_spec() implementation "
        "to return an instance of `tf.estimator.EvalSpec`.",
    )

    all_hooks = [*self.user_train_spec.hooks]

    if self.hvd_config.use:
        all_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

    # It is important that this hook is the final one in the list so that if
    # any other hooks need to run _before_ the training step ends they have
    # their chance.
    all_hooks.append(DeterminedControlHook(self))

    # TODO(DET-834): Separate step ID from data loader state.
    #
    # During warm start, we initialize model weights, optimizer state
    # and input state from the checkpoint, and we set the step ID to
    # 1. Trials typically use the step ID as an index into the data
    # sequence, which means there is an inconsistency between the
    # step ID (as data index) and the optimizer state and input state.
    #
    # In the short term, behave like other trials and reset input
    # state if we are warm started. This will create an inconsistency
    # wrt saved optimizer state.

    # Repeat the training dataset so we never run out of data.
    repeating_train_fn = self._check_and_repeat_train_input_fn(self.user_train_spec.input_fn)

    self.train_spec = tf.estimator.TrainSpec(input_fn=repeating_train_fn, hooks=all_hooks)
    self.eval_spec = tf.estimator.EvalSpec(input_fn=self.val_spec.input_fn, steps=None)
def from_native(context: det.NativeContext, *args: Any, **kwargs: Any) -> det.TrialController:
    check.is_instance(
        context,
        estimator.EstimatorNativeContext,
        "EstimatorTrialController needs an EstimatorSprinkleContext",
    )
    context = cast(estimator.EstimatorNativeContext, context)

    check.true(
        hasattr(context, "estimator")
        and hasattr(context, "train_spec")
        and hasattr(context, "eval_spec"),
        "Please call TFEstimatorExperiment.train_and_evaluate().",
    )

    return EstimatorTrialController(
        context.estimator,
        context.train_spec,
        context.eval_spec,
        context.serving_input_receiver_fns,
        context,
        *args,
        **kwargs,
    )
def _init_model(self) -> None:
    self._init_train_hooks()
    self._init_val_hooks()
    self._init_paths()

    self.estimator = tf.estimator.Estimator(
        model_fn=self._set_default_session_before_building_model(self.estimator._model_fn),
        config=self._init_run_config(self.estimator.config),
        params=self.estimator.params if self.estimator.params != {} else None,
        warm_start_from=self.estimator._warm_start_settings,
    )

    check.is_instance(
        self.estimator,
        tf.estimator.Estimator,
        "Please modify your model definition's build_estimator() implementation to return "
        "an instance of `tf.estimator.Estimator`.",
    )
    check.is_instance(
        self.user_train_spec,
        tf.estimator.TrainSpec,
        "Please modify your model definition's build_train_spec() implementation to return "
        "an instance of `tf.estimator.TrainSpec`.",
    )
    check.is_instance(
        self.val_spec,
        tf.estimator.EvalSpec,
        "Please modify your model definition's build_validation_spec() implementation "
        "to return an instance of `tf.estimator.EvalSpec`.",
    )

    # TODO(DET-834): Separate step ID from data loader state.
    #
    # During warm start, we initialize model weights, optimizer state
    # and input state from the checkpoint, and we set the step ID to
    # 1. Trials typically use the step ID as an index into the data
    # sequence, which means there is an inconsistency between the
    # step ID (as data index) and the optimizer state and input state.
    #
    # In the short term, behave like other trials and reset input
    # state if we are warm started. This will create an inconsistency
    # wrt saved optimizer state.

    # Repeat training dataset so we never run out of data.
    repeating_train_fn = self._check_and_repeat_train_input_fn(self.user_train_spec.input_fn)

    self.train_spec = tf.estimator.TrainSpec(
        input_fn=repeating_train_fn, hooks=self.train_hooks
    )
    self.eval_spec = tf.estimator.EvalSpec(
        input_fn=self.val_spec.input_fn, hooks=self._init_val_hooks(), steps=self.val_spec.steps
    )