def _compute_validation_metrics(self) -> workload.Response:
    self.context.experimental.reset_reducers()
    # Set the behavior of certain layers (e.g., dropout) that are
    # different between training and inference.
    for model in self.context.models:
        model.eval()

    for callback in self.callbacks.values():
        logging.warning(
            "on_validation_step_start is now deprecated, please use on_validation_start instead"
        )
        callback.on_validation_step_start()

    for callback in self.callbacks.values():
        callback.on_validation_start()

    num_inputs = 0
    metrics = {}  # type: Dict[str, Any]

    if self._evaluate_batch_defined():
        keys = None
        batch_metrics = []

        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        check.gt(len(self.validation_loader), 0)
        for batch in self.validation_loader:
            batch = self.context.to_device(batch)
            num_inputs += pytorch.data_length(batch)

            vld_metrics = self.trial.evaluate_batch(batch=batch)
            # Verify validation metric names are the same across batches.
            if keys is None:
                keys = vld_metrics.keys()
            else:
                check.eq(
                    keys,
                    vld_metrics.keys(),
                    "Validation metric names must match across all batches of data.",
                )
            check.is_instance(
                vld_metrics,
                dict,
                "validation_metrics() must return a "
                "dictionary of string names to Tensor "
                "metrics",
            )
            # TODO: For performance perform -> cpu() only at the end of validation.
            batch_metrics.append(self._convert_metrics_to_numpy(vld_metrics))
            if self.env.test_mode:
                break

        metrics = self._reduce_metrics(
            batch_metrics=batch_metrics,
            keys=keys,
            metrics_reducers=self._prepare_metrics_reducers(keys=keys),
        )

        if self.hvd_config.use:
            num_inputs *= hvd.size()

    else:
        check.true(self._evaluate_full_dataset_defined())
        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        if self.is_chief:
            metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader)

            check.is_instance(
                metrics, dict, f"eval() must return a dictionary, got {type(metrics)}."
            )

            metrics = self._convert_metrics_to_numpy(metrics)
            num_inputs = self.context.get_per_slot_batch_size() * len(self.validation_loader)

    metrics.update(
        self._convert_metrics_to_numpy(
            self.context.experimental.reduce_metrics(for_training=False)
        )
    )

    if self.hvd_config.use and any(
        map(
            lambda c: util.is_overridden(c.on_validation_end, pytorch.PyTorchCallback)
            or util.is_overridden(c.on_validation_step_end, pytorch.PyTorchCallback),
            self.callbacks.values(),
        )
    ):
        logging.debug(
            "Broadcasting metrics to all worker processes to execute a "
            "validation step end callback"
        )
        metrics = hvd.broadcast_object(metrics, root_rank=0)

    for callback in self.callbacks.values():
        logging.warning(
            "on_validation_step_end is now deprecated, please use on_validation_end instead"
        )
        callback.on_validation_step_end(metrics)

    for callback in self.callbacks.values():
        callback.on_validation_end(metrics)

    if not self.is_chief:
        return workload.Skipped()

    return {"num_inputs": num_inputs, "validation_metrics": metrics}
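# Illustrative sketch (not part of the controller above): a minimal
# PyTorchTrial.evaluate_batch() of the shape the loop above expects, i.e. a dict
# mapping metric names to Tensors or numbers, with identical keys on every
# batch (enforced by the check.eq() call). The trial class, metric names, and
# model are hypothetical; the other required PyTorchTrial methods
# (train_batch, data loader builders) are omitted for brevity.
import torch
import torch.nn.functional as F
from determined import pytorch as det_pytorch


class MyTrial(det_pytorch.PyTorchTrial):
    def __init__(self, context: det_pytorch.PyTorchTrialContext) -> None:
        self.context = context
        self.model = self.context.wrap_model(torch.nn.Linear(10, 2))
        self.optimizer = self.context.wrap_optimizer(
            torch.optim.SGD(self.model.parameters(), lr=0.01)
        )

    def evaluate_batch(self, batch: det_pytorch.TorchData):
        data, labels = batch
        output = self.model(data)
        # Keys must match across all validation batches; values are reduced
        # by the controller after the loop.
        return {
            "validation_loss": F.cross_entropy(output, labels),
            "accuracy": (output.argmax(dim=1) == labels).float().mean(),
        }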
def _compute_validation_metrics(self) -> workload.Response:
    self.context.reset_reducers()
    # Set the behavior of certain layers (e.g., dropout) that are
    # different between training and inference.
    for model in self.context.models:
        model.eval()

    step_start_time = time.time()

    for callback in self.callbacks.values():
        if util.is_overridden(callback.on_validation_step_start, pytorch.PyTorchCallback):
            logging.warning(
                "on_validation_step_start is now deprecated, "
                "please use on_validation_start instead"
            )
            callback.on_validation_step_start()

    for callback in self.callbacks.values():
        callback.on_validation_start()

    num_inputs = 0
    metrics = {}  # type: Dict[str, Any]

    if self._evaluate_batch_defined():
        keys = None
        batch_metrics = []

        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        check.gt(len(self.validation_loader), 0)
        for callback in self.callbacks.values():
            callback.on_validation_epoch_start()

        for idx, batch in enumerate(self.validation_loader):
            if self.context.experimental._auto_to_device:
                batch = self.context.to_device(batch)
            num_inputs += self.trial.get_batch_length(batch)

            if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                vld_metrics = self.trial.evaluate_batch(batch=batch, batch_idx=idx)
            else:
                vld_metrics = self.trial.evaluate_batch(batch=batch)  # type: ignore
            # Verify validation metric names are the same across batches.
            if keys is None:
                keys = vld_metrics.keys()
            else:
                check.eq(
                    keys,
                    vld_metrics.keys(),
                    "Validation metric names must match across all batches of data.",
                )
            check.is_instance(
                vld_metrics,
                dict,
                "validation_metrics() must return a "
                "dictionary of string names to Tensor "
                "metrics",
            )
            # TODO: For performance perform -> cpu() only at the end of validation.
            batch_metrics.append(pytorch._convert_metrics_to_numpy(vld_metrics))
            if self.env.test_mode:
                break

        for callback in self.callbacks.values():
            callback.on_validation_epoch_end(batch_metrics)

        metrics = pytorch._reduce_metrics(
            self.context.distributed,
            batch_metrics=batch_metrics,
            keys=keys,
            metrics_reducers=pytorch._prepare_metrics_reducers(
                self.trial.evaluation_reducer(), keys=keys
            ),
        )

        # Gather a list of per-worker (num_inputs, num_batches) tuples.
        input_counts = self.context.distributed.gather((num_inputs, idx + 1))
        if self.context.distributed.rank == 0:
            assert input_counts is not None
            # Reshape and sum.
            num_inputs, num_batches = [sum(n) for n in zip(*input_counts)]

    else:
        check.true(self._evaluate_full_dataset_defined())
        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        if self.is_chief:
            metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader)

            check.is_instance(
                metrics, dict, f"eval() must return a dictionary, got {type(metrics)}."
            )

            metrics = pytorch._convert_metrics_to_numpy(metrics)
            num_inputs = self.context.get_per_slot_batch_size() * len(self.validation_loader)

    metrics.update(
        pytorch._convert_metrics_to_numpy(self.context.reduce_metrics(for_training=False))
    )

    if self.context.distributed.size > 1 and any(
        map(
            lambda c: util.is_overridden(c.on_validation_end, pytorch.PyTorchCallback)
            or util.is_overridden(c.on_validation_step_end, pytorch.PyTorchCallback),
            self.callbacks.values(),
        )
    ):
        logging.debug(
            "Broadcasting metrics to all worker processes to execute a "
            "validation step end callback"
        )
        metrics = hvd.broadcast_object(metrics, root_rank=0)

    for callback in self.callbacks.values():
        if util.is_overridden(callback.on_validation_step_end, pytorch.PyTorchCallback):
            logging.warning(
                "on_validation_step_end is now deprecated, please use on_validation_end instead"
            )
            callback.on_validation_step_end(metrics)

    for callback in self.callbacks.values():
        callback.on_validation_end(metrics)

    if not self.is_chief:
        return {}

    # Skip reporting timings if evaluate_full_dataset() was defined. This is far less common
    # than evaluate_batch() and we can't know how the user processed their validation data.
    if self._evaluate_batch_defined():
        step_duration = time.time() - step_start_time
        logging.info(
            det.util.make_timing_log("validated", step_duration, num_inputs, num_batches)
        )

    return {"num_inputs": num_inputs, "validation_metrics": metrics}
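# Illustrative sketch (not part of the controller above): a PyTorchCallback
# using the non-deprecated hooks that the newer loop invokes. The class name
# and log messages are hypothetical. on_validation_epoch_end() receives the
# list of per-batch metric dicts collected by this worker, and
# on_validation_end() receives the reduced metrics dict (broadcast to all
# workers when the hook is overridden, as shown above).
import logging

from determined import pytorch as det_pytorch


class ValidationLogger(det_pytorch.PyTorchCallback):
    def on_validation_start(self) -> None:
        logging.info("validation starting")

    def on_validation_epoch_end(self, outputs) -> None:
        logging.info("validated %d batches", len(outputs))

    def on_validation_end(self, metrics) -> None:
        logging.info("validation metrics: %s", metrics)


# A trial would typically register this by returning
# {"validation_logger": ValidationLogger()} from its build_callbacks() method.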