def _run_epoch(self, loaders):
    # @TODO: better solution with train/inference handling?
    if not self.state.stage.startswith("infer"):
        assert self.state.valid_loader in loaders.keys(), \
            f"'{self.state.valid_loader}' " \
            f"should be in provided loaders: {list(loaders.keys())}"
    else:
        assert not any(x.startswith("train") for x in loaders.keys()), \
            "for inference no train loader should be passed"

    for loader_name, loader in loaders.items():
        self.state.loader_name = loader_name
        self.state.loader_len = len(loader)
        self.state.need_backward = loader_name.startswith("train")
        # switch train/eval mode (recursively, in case of a model dict)
        utils.maybe_recursive_call(
            self.model, "train", mode=self.state.need_backward
        )

        if isinstance(loader.sampler, DistributedSampler) \
                and loader_name.startswith("train"):
            loader.sampler.set_epoch(self.state.stage_epoch)

        utils.set_global_seed(
            self.experiment.initial_seed + self.state.epoch + 1
        )
        self._run_event("loader_start")
        # disable autograd outside of train loaders
        with torch.set_grad_enabled(self.state.need_backward):
            self._run_loader(loader)
        self._run_event("loader_end")
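
# A minimal standalone sketch (not Catalyst code) of why ``set_epoch`` is
# called on ``DistributedSampler`` above: the sampler seeds its shuffle from
# its epoch counter, so without the call every epoch replays the same
# per-worker sample order. All names below are illustrative.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)
for epoch in range(2):
    sampler.set_epoch(epoch)  # reseeds the shuffle for this epoch
    print([batch[0].tolist() for batch in loader])
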
def on_batch_end(self, state):
    if not state.need_backward:
        return

    loss = self._get_loss(state)
    self._accumulation_counter += 1
    model = state.model
    optimizer = state.get_key(
        key="optimizer", inner_key=self.optimizer_key
    )

    # This is a very hacky check whether we have an AMP optimizer;
    # it may change in the future.
    # The alternative solution is to have an AmpOptimizerCallback
    # or to expose another constructor argument.
    if hasattr(optimizer, "_amp_stash"):
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    # the counter was incremented above, so a gradient step is due
    # once every ``accumulation_steps`` batches
    if self._accumulation_counter % self.accumulation_steps == 0:
        self.grad_step(
            optimizer=optimizer,
            optimizer_wds=self._optimizer_wd,
            grad_clip_fn=self.grad_clip_fn,
        )
        maybe_recursive_call(model, "zero_grad")
        self._accumulation_counter = 0
def predict_loader(
    self,
    *,
    loader: DataLoader,
    model: Model = None,
    resume: str = None,
    fp16: Union[Dict, bool] = None,
    initial_seed: int = 42,
) -> Generator:
    """
    Runs model inference on a PyTorch DataLoader and returns
    a Python generator with model predictions from `runner.predict_batch`.

    Cleans up the experiment info to avoid possible collisions.
    Sets `is_train_loader` and `is_valid_loader` to `False`,
    while keeping `is_infer_loader` as `True`.
    Moves model to evaluation mode.

    Args:
        loader: loader to predict
        model: model to use for prediction
        resume: path to checkpoint to resume from
        fp16 (Union[Dict, bool]): fp16 usage flag
        initial_seed: seed to set before prediction

    Yields:
        batches with model predictions
    """
    if isinstance(fp16, bool) and fp16:
        fp16 = {"opt_level": "O1"}

    if model is not None:
        self.model = model
    assert self.model is not None

    if resume is not None:
        checkpoint = utils.load_checkpoint(resume)
        utils.unpack_checkpoint(checkpoint, model=self.model)

    self.experiment = None
    utils.set_global_seed(initial_seed)
    (model, _, _, _, device) = utils.process_components(  # noqa: WPS122
        model=self.model,
        distributed_params=fp16,
        device=self.device,
    )
    self._prepare_inner_state(
        stage="infer",
        model=model,
        device=device,
        is_train_loader=False,
        is_valid_loader=False,
        is_infer_loader=True,
    )
    utils.maybe_recursive_call(self.model, "train", mode=False)

    utils.set_global_seed(initial_seed)
    for batch in loader:
        yield self.predict_batch(batch)
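
# A hedged usage sketch for ``predict_loader``. ``SupervisedRunner`` is the
# stock Catalyst runner; the toy model, features, and loader below are
# assumptions made for illustration only.
import torch
import torch.nn as nn
from catalyst.dl import SupervisedRunner
from torch.utils.data import DataLoader, TensorDataset

features, targets = torch.randn(16, 10), torch.randint(0, 2, (16,))
loader = DataLoader(TensorDataset(features, targets), batch_size=4)
runner = SupervisedRunner(model=nn.Linear(10, 2))
for prediction in runner.predict_loader(loader=loader, initial_seed=42):
    ...  # one item per batch, as returned by ``runner.predict_batch``
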
def model(self, value: Union[Model, Dict[str, Model]]):
    """Setter for the runner's model."""
    if isinstance(value, nn.Module):
        model = value
    elif isinstance(value, dict):
        values_are_models = all(
            isinstance(v, nn.Module) for v in value.values()
        )
        if not values_are_models:
            raise TypeError(
                "Invalid dict value type, must be `torch.nn.Module`"
            )
        model = value
    else:
        raise TypeError(
            f"Invalid value type, "
            f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]`, "
            f"got '{type(value)}'"
        )

    if self._device is not None:
        model: Model = utils.maybe_recursive_call(
            model, "to", device=self._device
        )

    self._model = model
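
# A hedged sketch of what the setter accepts: a single module, or a dict of
# named modules (useful e.g. for GANs). The ``generator``/``discriminator``
# names are illustrative assumptions, not required by the runner.
import torch.nn as nn
from catalyst.dl import SupervisedRunner

runner = SupervisedRunner()
runner.model = nn.Linear(10, 2)          # single model
runner.model = {                         # dict of named models
    "generator": nn.Linear(8, 16),
    "discriminator": nn.Linear(16, 1),
}
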
def _process_trial_config(trial, config: Dict) -> Tuple[optuna.Trial, Dict]:
    def _eval_trial_suggestions(x):
        nonlocal trial
        # NB: ``eval`` executes arbitrary strings from the config,
        # so configs must come from a trusted source
        if isinstance(x, str) and "trial.suggest_" in x:
            x = eval(x)
        return x

    config = utils.maybe_recursive_call(config, _eval_trial_suggestions)
    return trial, config
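
# A minimal sketch of the config convention ``_process_trial_config`` expects:
# string values containing ``trial.suggest_`` are evaluated with the live
# ``trial`` in scope. The keys below are illustrative, not a required schema.
import optuna

def objective(trial: optuna.Trial) -> float:
    config = {
        "lr": "trial.suggest_loguniform('lr', 1e-5, 1e-1)",
        "model": {"num_hidden": "trial.suggest_int('num_hidden', 32, 128)"},
    }
    trial, config = _process_trial_config(trial, config)
    # ``config`` now holds sampled values instead of suggestion strings
    return config["lr"]

study = optuna.create_study()
study.optimize(objective, n_trials=3)
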
def on_batch_end(self, state):
    """On batch end event"""
    if not state.need_backward:
        return

    loss = self._get_loss(state)
    self._accumulation_counter += 1
    model = state.model
    optimizer = state.get_key(
        key="optimizer", inner_key=self.optimizer_key
    )

    # the counter was incremented above, so a gradient step is due
    # once every ``accumulation_steps`` batches
    need_gradient_step = \
        self._accumulation_counter % self.accumulation_steps == 0

    # This is a very hacky check whether we have an AMP optimizer;
    # it may change in the future.
    # The alternative solution is to have an AmpOptimizerCallback
    # or to expose another constructor argument.
    if hasattr(optimizer, "_amp_stash"):
        from apex import amp
        # ``delay_unscale`` must be set while accumulating, according to
        # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
        delay_unscale = not need_gradient_step
        with amp.scale_loss(
            loss, optimizer, delay_unscale=delay_unscale
        ) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    if need_gradient_step:
        self.grad_step(
            optimizer=optimizer,
            optimizer_wds=self._optimizer_wd,
            grad_clip_fn=self.grad_clip_fn,
        )

        if self.save_model_grads:
            for tag, value in model.named_parameters():
                tag = tag.replace(".", "/")
                state.model_grads[tag] = value.grad.cpu().numpy()

        maybe_recursive_call(model, "zero_grad")
        self._accumulation_counter = 0
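
# A minimal standalone sketch of the accumulation pattern implemented by the
# callback above, in plain PyTorch. The loss scaling by ``accumulation_steps``
# is a common addition for comparable gradient magnitudes, not something the
# callback itself does.
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
accumulation_steps = 4

for step, batch in enumerate(torch.randn(16, 8, 10), start=1):
    loss = model(batch).mean() / accumulation_steps
    loss.backward()                          # gradients accumulate in .grad
    if step % accumulation_steps == 0:
        optimizer.step()                     # one update per 4 batches
        optimizer.zero_grad()
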
def device(self, value: Device):
    """Setter for the runner's device."""
    if isinstance(value, (str, torch.device)):
        self._device = value
    else:
        raise TypeError(
            f"Invalid value type, "
            f"must be `str` or `torch.device`, "
            f"got '{type(value)}'"
        )

    if self._model is not None:
        self._model = utils.maybe_recursive_call(
            self._model, "to", device=self._device
        )