Example #1
    def _run_epoch(self, loaders):
        # @TODO: better solution for train/inference handling?
        if not self.state.stage.startswith("infer"):
            assert self.state.valid_loader in loaders.keys(), \
                f"'{self.state.valid_loader}' " \
                f"should be in provided loaders: {list(loaders.keys())}"
        else:
            assert not any(x.startswith("train") for x in loaders.keys()), \
                "for inference no train loader should be passed"

        for loader_name, loader in loaders.items():
            self.state.loader_name = loader_name
            self.state.loader_len = len(loader)
            self.state.need_backward = loader_name.startswith("train")
            utils.maybe_recursive_call(self.model,
                                       "train",
                                       mode=self.state.need_backward)

            if isinstance(loader.sampler, DistributedSampler) \
                    and loader_name.startswith("train"):
                loader.sampler.set_epoch(self.state.stage_epoch)

            utils.set_global_seed(self.experiment.initial_seed +
                                  self.state.epoch + 1)
            self._run_event("loader_start")
            with torch.set_grad_enabled(self.state.need_backward):
                self._run_loader(loader)
            self._run_event("loader_end")
Example #2
    def on_batch_end(self, state):
        if not state.need_backward:
            return

        loss = self._get_loss(state)

        self._accumulation_counter += 1
        model = state.model
        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)

        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future. The alternative would be a
        # dedicated AmpOptimizerCallback or another constructor argument.
        if hasattr(optimizer, "_amp_stash"):
            from apex import amp
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if self._accumulation_counter % self.accumulation_steps == 0:
            self.grad_step(optimizer=optimizer,
                           optimizer_wds=self._optimizer_wd,
                           grad_clip_fn=self.grad_clip_fn)
            maybe_recursive_call(model, "zero_grad")

            self._accumulation_counter = 0
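
For reference, the accumulation bookkeeping above reduces to the following standalone loop (with the common extra step of averaging the loss over the window); the model, data, and `accumulation_steps` here are hypothetical stand-ins:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

model.zero_grad()
for step in range(8):
    x, y = torch.randn(16, 10), torch.randn(16, 1)
    # average the loss over the accumulation window so the effective
    # batch size grows without changing the gradient scale
    loss = nn.functional.mse_loss(model(x), y) / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across iterations
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        model.zero_grad()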
Example #3
    def predict_loader(
        self,
        *,
        loader: DataLoader,
        model: Model = None,
        resume: str = None,
        fp16: Union[Dict, bool] = None,
        initial_seed: int = 42,
    ) -> Generator:
        """
        Runs model inference on a PyTorch DataLoader and returns
        a Python generator with model predictions from `runner.predict_batch`.
        Cleans up the experiment info to avoid possible collisions.
        Sets `is_train_loader` and `is_valid_loader` to `False` while
        keeping `is_infer_loader` as `True`. Moves the model to evaluation mode.

        Args:
            loader: loader to predict
            model: model to use for prediction
            resume: path to checkpoint to resume
            fp16 (Union[Dict, bool]): fp16 usage flag
            initial_seed: seed to use before prediction

        Yields:
            batches with model predictions
        """
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}

        if model is not None:
            self.model = model
        assert self.model is not None

        if resume is not None:
            checkpoint = utils.load_checkpoint(resume)
            utils.unpack_checkpoint(checkpoint, model=self.model)

        self.experiment = None
        utils.set_global_seed(initial_seed)
        (model, _, _, _, device) = utils.process_components(  # noqa: WPS122
            model=self.model,
            distributed_params=fp16,
            device=self.device,
        )
        self._prepare_inner_state(
            stage="infer",
            model=model,
            device=device,
            is_train_loader=False,
            is_valid_loader=False,
            is_infer_loader=True,
        )
        utils.maybe_recursive_call(self.model, "train", mode=False)

        utils.set_global_seed(initial_seed)
        for batch in loader:
            yield self.predict_batch(batch)
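
A hypothetical call site, assuming a runner instance, a trained `model`, and a `test_loader` already exist; each yielded item is whatever `predict_batch` returns for one batch:

predictions = []
for batch_output in runner.predict_loader(
    loader=test_loader,
    model=model,
    resume="logs/checkpoints/best.pth",  # assumed checkpoint path
):
    predictions.append(batch_output)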
Example #4
    def model(self, value: Union[Model, Dict[str, Model]]):
        """
        Setter for the runner's model.
        """
        if isinstance(value, nn.Module):
            model = value
        elif isinstance(value, dict):
            values_are_models = all(
                isinstance(v, nn.Module) for v in value.values()
            )
            if not values_are_models:
                raise TypeError(
                    "Invalid dict value type, must be `torch.nn.Module`")

            model = value

        else:
            raise TypeError(
                f"Invalid value type: "
                f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]`, "
                f"got '{type(value)}'")

        if self._device is not None:
            model: Model = utils.maybe_recursive_call(model,
                                                      "to",
                                                      device=self._device)

        self._model = model
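
A sketch of what the setter accepts, with `runner` assumed to be an instance of the class above:

import torch.nn as nn

runner.model = nn.Linear(10, 2)  # a single model ...

runner.model = {  # ... or a name -> model dict, e.g. for GAN-style setups
    "generator": nn.Linear(8, 8),
    "discriminator": nn.Linear(8, 1),
}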
Example #5
def _process_trial_config(trial, config: Dict) -> Tuple[optuna.Trial, Dict]:
    def _eval_trial_suggestions(x):
        nonlocal trial
        # config values such as "trial.suggest_int('lr', ...)" are stored
        # as strings; evaluate them against the live optuna trial
        if isinstance(x, str) and "trial.suggest_" in x:
            x = eval(x)
        return x

    config = utils.maybe_recursive_call(config, _eval_trial_suggestions)
    return trial, config
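
A sketch of the kind of config this helper expands; the string values are evaluated against the live optuna `trial` (inside an `objective(trial)` function), so after the call they hold concrete sampled numbers:

config = {
    "model": {"num_hidden": "trial.suggest_int('num_hidden', 32, 128)"},
    "optimizer": {"lr": "trial.suggest_loguniform('lr', 1e-5, 1e-2)"},
}
trial, config = _process_trial_config(trial, config)
# config["model"]["num_hidden"] is now an int chosen by optuna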
Example #6
    def on_batch_end(self, state):
        """On batch end event"""
        if not state.need_backward:
            return

        loss = self._get_loss(state)

        self._accumulation_counter += 1
        model = state.model
        optimizer = state.get_key(key="optimizer",
                                  inner_key=self.optimizer_key)

        need_gradient_step = \
            self._accumulation_counter % self.accumulation_steps == 0

        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future. The alternative would be a
        # dedicated AmpOptimizerCallback or another constructor argument.
        if hasattr(optimizer, "_amp_stash"):
            from apex import amp
            # ``delay_unscale`` must be set when accumulating gradients; see
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss, optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(optimizer=optimizer,
                           optimizer_wds=self._optimizer_wd,
                           grad_clip_fn=self.grad_clip_fn)

            if self.save_model_grads:
                for tag, value in model.named_parameters():
                    tag = tag.replace(".", "/")
                    state.model_grads[tag] = value.grad.cpu().numpy()

            maybe_recursive_call(model, "zero_grad")

            self._accumulation_counter = 0
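
The apex pattern the comment links to looks like this in isolation; `compute_loss`, `loader`, and `accumulation_steps` are assumed, and `model`/`optimizer` must already be wrapped via `amp.initialize`:

from apex import amp

for step, batch in enumerate(loader):
    loss = compute_loss(batch)  # assumed helper
    need_gradient_step = (step + 1) % accumulation_steps == 0
    # keep gradients scaled (unscaling delayed) until the step where we
    # actually apply them, as the apex gradient-accumulation docs recommend
    with amp.scale_loss(loss, optimizer,
                        delay_unscale=not need_gradient_step) as scaled_loss:
        scaled_loss.backward()
    if need_gradient_step:
        optimizer.step()
        optimizer.zero_grad()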
Example #7
    def device(self, value: Device):
        """
        Setter for the runner's device.
        """
        if isinstance(value, (str, torch.device)):
            self._device = value
        else:
            raise TypeError(f"Invalid value type "
                            f"must be `str` or `torch.device` "
                            f"got '{type(value)}'")

        if self._model is not None:
            self._model = utils.maybe_recursive_call(self._model,
                                                     "to",
                                                     device=self._device)
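
A brief sketch of the setter in use, with `runner` assumed to exist; assigning a device also moves an already-set model, as the code above shows:

import torch

runner.device = "cuda:0" if torch.cuda.is_available() else "cpu"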