def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)

    check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
    self.trial = cast(PyTorchTrial, trial_inst)

    self._check_evaluate_implementation()
    self._init_model_and_optimizer()

    # Validation loader will be undefined on process ranks > 0
    # when the user defines `validate_full_dataset()`.
    self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
    self._set_data_loaders()

    # Track whether a warning logging category has already been issued to the user.
    self.warning_logged = {_WarningLogs.FAILED_MOVING_TO_DEVICE: False}

    self.context.lr_scheduler = self.trial.create_lr_scheduler(self.context.optimizer)

    self.callbacks = self.trial.build_callbacks()

    # If a load path is provided, load weights and restore the data location.
    self._load()
    self._configure_amp()

    if self.hvd_config.use:
        hvd.broadcast_parameters(self.context.model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(self.context.optimizer, root_rank=0)

    self.training_iterator = iter(self.training_loader)
def _init_model(self) -> None:
    self.optimizer = self.trial.optimizer(self.model)
    # TODO: Check that optimizer is not an amp optimizer.

    self._init_device()
    self.model = self.model.to(self.device)

    if self.hvd_config.use:
        use_compression = self.hvd_config.fp16_compression
        self.optimizer = hvd.DistributedOptimizer(
            self.optimizer,
            named_parameters=self.model.named_parameters(),
            backward_passes_per_step=self.hvd_config.aggregation_frequency,
            compression=hvd.Compression.fp16 if use_compression else hvd.Compression.none,
        )
        logging.debug("Initialized optimizer for distributed and optimized parallel training.")
    elif self.n_gpus > 1:
        check.eq(
            self.hvd_config.aggregation_frequency,
            1,
            "Please enable `optimized_parallel` to use aggregation "
            "frequency greater than 1 for single machine multi-GPU "
            "training.",
        )
        self.model = nn.DataParallel(self.model)
        logging.debug("Initialized model for native parallel training.")

    self.lr_helper = _LRHelper(self.trial.create_lr_scheduler(self.optimizer))

    # If a load path is provided, load weights and restore the data location.
    self._load()
    self._configure_amp()

    if self.hvd_config.use:
        hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)

    # Initialize training and validation iterators.
    self.training_iterator = iter(self.training_loader)
def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)

    check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
    self.trial = cast(PyTorchTrial, trial_inst)
    self.context = cast(pytorch.PyTorchTrialContext, self.context)
    self.context.experimental._set_allgather_fn(self.allgather_metrics)
    self.callbacks = self.trial.build_callbacks()

    self._apply_backwards_compatibility()

    check.gt_eq(
        len(self.context.models),
        1,
        "Must have at least one model. "
        "This might be caused by not wrapping your model with wrap_model()",
    )
    check.gt_eq(
        len(self.context.optimizers),
        1,
        "Must have at least one optimizer. "
        "This might be caused by not wrapping your optimizer with wrap_optimizer()",
    )
    self._check_evaluate_implementation()

    # Validation loader will be undefined on process ranks > 0
    # when the user defines `validate_full_dataset()`.
    self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
    self._set_data_loaders()

    # We don't want the training_iterator shuffling values after we load state.
    self.training_iterator = iter(self.training_loader)

    # If a load path is provided, load weights and restore the data location.
    self._load()

    if self.hvd_config.use:
        hvd.broadcast_parameters(self.context._main_model.state_dict(), root_rank=0)
        for optimizer in self.context.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
def run(self) -> None:
    # We create the training_iterator here rather than in __init__ because we have to be careful
    # to trigger its shutdown explicitly, to avoid hangs when the user is using
    # multiprocessing-based parallelism for their dataloader.
    #
    # We create it before loading state because we don't want the training_iterator shuffling
    # values after we load state.
    self.training_iterator = iter(self.training_loader)
    try:
        self._load()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context._main_model.state_dict(), root_rank=0)
            for optimizer in self.context.optimizers:
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        with self.prof:
            self._run()
    finally:
        # Explicitly trigger the training iterator's shutdown (which happens in __del__).
        # See the rather long note in pytorch/torch/utils/data/dataloader.py.
        del self.training_iterator
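# Illustrative sketch (not part of the controller code above): why run() triggers the
# training iterator's shutdown explicitly. With a multiprocessing DataLoader
# (num_workers > 0) the iterator owns worker processes that are only reaped by its
# __del__, so leaving a live reference around can hang the process on exit. All names
# below are hypothetical and only assume the public torch.utils.data API.
import torch
from torch.utils.data import DataLoader, TensorDataset


def consume_one_batch() -> None:
    loader = DataLoader(
        TensorDataset(torch.arange(100).float()), batch_size=10, num_workers=2
    )
    it = iter(loader)
    try:
        next(it)  # Partially consume the iterator; its worker processes stay alive.
    finally:
        # Dropping the last reference runs the iterator's __del__, which shuts the
        # workers down instead of relying on interpreter teardown to do it.
        del it


if __name__ == "__main__":
    consume_one_batch()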
def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)

    check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
    self.trial = cast(PyTorchTrial, trial_inst)
    self.context = cast(PyTorchTrialContext, self.context)
    self.callbacks = self.trial.build_callbacks()

    # TODO(DET-3262): remove this backward compatibility of old interface.
    if (
        util.is_overridden(self.trial.build_model, PyTorchTrial)
        or util.is_overridden(self.trial.optimizer, PyTorchTrial)
        or util.is_overridden(self.trial.create_lr_scheduler, PyTorchTrial)
    ):
        check.true(
            util.is_overridden(self.trial.build_model, PyTorchTrial)
            and util.is_overridden(self.trial.optimizer, PyTorchTrial),
            "Both build_model() and optimizer() must be defined "
            "if any of build_model(), optimizer(), and create_lr_scheduler() are defined. "
            "If you want to use the new interface, you should instead instantiate your models, "
            "optimizers, and LR schedulers in __init__ and call context.backward(loss) "
            "and context.step_optimizer(optimizer) in train_batch.",
        )

        model = self.context._Model(self.trial.build_model())
        optim = self.context._Optimizer(self.trial.optimizer(model))

        lr_scheduler = self.trial.create_lr_scheduler(optim)
        if lr_scheduler is not None:
            self.context.lr_schedulers.append(lr_scheduler)

        if det.ExperimentConfig(self.context.get_experiment_config()).mixed_precision_enabled():
            self.context._configure_apex_amp(
                models=model,
                optimizers=optim,
                opt_level=self.context.get_experiment_config()
                .get("optimizations", {})
                .get("mixed_precision", "O0"),
            )

        train_batch = self.trial.train_batch

        def new_train_batch(
            batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
        ) -> Union[torch.Tensor, Dict[str, Any]]:
            tr_metrics = train_batch(batch, model, epoch_idx, batch_idx)
            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}
            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )
            check.is_in(
                "loss", tr_metrics.keys(), 'Please include "loss" in your training metrics.'
            )

            def clip_grads(parameters: Iterator) -> None:
                for callback in self.callbacks.values():
                    callback.on_before_optimizer_step(parameters)

            self.context._backward(tr_metrics["loss"])
            self.context._step_optimizer(self.context.optimizers[0], clip_grads=clip_grads)

            return tr_metrics

        self.trial.__setattr__("train_batch", new_train_batch)

    check.gt_eq(
        len(self.context.models),
        1,
        "Must have at least one model. "
        "This might be caused by not wrapping your model with Model()",
    )
    check.gt_eq(
        len(self.context.optimizers),
        1,
        "Must have at least one optimizer. "
        "This might be caused by not wrapping your optimizer with Optimizer()",
    )
    self._check_evaluate_implementation()

    # Validation loader will be undefined on process ranks > 0
    # when the user defines `validate_full_dataset()`.
    self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
    self._set_data_loaders()

    # If a load path is provided, load weights and restore the data location.
    self._load()

    if self.hvd_config.use:
        hvd.broadcast_parameters(self.context._main_model.state_dict(), root_rank=0)
        for optimizer in self.context.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    self.training_iterator = iter(self.training_loader)
def run(self) -> None:
    @contextlib.contextmanager
    def defer(fn: Callable, *args: Any) -> Iterator[None]:
        try:
            yield
        finally:
            fn(*args)

    # We define on_shutdown here instead of inside the `for callback in ...` loop to ensure we
    # don't bind the loop iteration variable `callback`, which would likely cause us to call
    # on_trial_shutdown() multiple times for the final callback, and not at all for the others.
    def on_shutdown(callback_name: str, on_trial_shutdown: Callable) -> None:
        with self.prof.record_timing(f"callbacks.{callback_name}.on_trial_shutdown"):
            on_trial_shutdown()

    with contextlib.ExitStack() as exit_stack:
        for callback in self.callbacks.values():
            with self.prof.record_timing(
                f"callbacks.{callback.__class__.__name__}.on_trial_startup"
            ):
                callback.on_trial_startup(self.steps_completed, self.env.latest_checkpoint)
            exit_stack.enter_context(
                defer(on_shutdown, callback.__class__.__name__, callback.on_trial_shutdown)
            )

        self._set_data_loaders()

        # We create the training_iterator here rather than in __init__ because we have to be
        # careful to trigger its shutdown explicitly, to avoid hangs when the user is using
        # multiprocessing-based parallelism for their dataloader.
        #
        # We create it before loading state because we don't want the training_iterator
        # shuffling values after we load state.
        self.training_iterator = iter(self.training_loader)

        def cleanup_iterator() -> None:
            # Explicitly trigger the training iterator's shutdown (which happens in __del__).
            # See the rather long note in pytorch/torch/utils/data/dataloader.py.
            del self.training_iterator

        exit_stack.enter_context(defer(cleanup_iterator))

        # If a load path is provided, load weights and restore the data location.
        if self.env.latest_checkpoint is not None:
            logging.info(f"Restoring trial from checkpoint {self.env.latest_checkpoint}")
            with self.context._core.checkpoint.restore_path(
                self.env.latest_checkpoint
            ) as load_path:
                self._load(load_path)

        if self.context.distributed.size > 1 and self.use_horovod:
            hvd.broadcast_parameters(self.context._main_model.state_dict(), root_rank=0)
            for optimizer in self.context.optimizers:
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        with self.prof:
            for callback in self.callbacks.values():
                with self.prof.record_timing(
                    f"callbacks.{callback.__class__.__name__}.on_training_start"
                ):
                    callback.on_training_start()

            self._run()
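# Illustrative sketch (not from the source above): the late-binding pitfall that the
# on_shutdown comment in run() guards against, and the LIFO unwind order of ExitStack.
# Closures created inside a loop capture the loop variable itself, not its value at that
# iteration, so deferred lambdas would all see the final callback; passing the current
# value as an argument to defer() avoids that. Names below are hypothetical.
import contextlib
from typing import Any, Callable, Iterator, List


@contextlib.contextmanager
def defer(fn: Callable, *args: Any) -> Iterator[None]:
    try:
        yield
    finally:
        fn(*args)


names = ["a", "b", "c"]
seen: List[str] = []

with contextlib.ExitStack() as stack:
    for name in names:
        # Wrong: defer(lambda: seen.append(name)) would append "c" three times, because
        # each lambda looks `name` up only when it finally runs, after the loop is done.
        # Right: bind the current value now by passing it as an argument.
        stack.enter_context(defer(seen.append, name))

# ExitStack unwinds its contexts in LIFO order, so the deferred calls run in reverse.
assert seen == ["c", "b", "a"]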