예제 #1
0
파일: model.py 프로젝트: lRomul/argus
    def save(self, file_path: types.Path, optimizer_state: bool = False):
        """Serialize the argus model to ``file_path`` with ``torch.save``.

        The stored object is a dict of the form::

            {
                'model_name': Name of the argus model class,
                'params': Argus model parameters dict,
                'nn_state_dict': torch nn_module.state_dict(),
                'optimizer_state_dict': torch optimizer.state_dict()
            }

        Each *state_dict* is passed through ``deep_to(..., 'cpu')`` before it
        is written. The ``'optimizer_state_dict'`` entry is only present when
        ``optimizer_state`` is true and the model has an optimizer.

        Args:
            file_path (str or :class:`pathlib.Path`): Destination of the argus
                model file.
            optimizer_state (bool): Also store the optimizer state.
                Defaults to False.

        """
        module = self.get_nn_module()
        checkpoint = dict(
            model_name=self.__class__.__name__,
            params=self.params,
            nn_state_dict=deep_to(module.state_dict(), 'cpu'),
        )

        # The optimizer state is optional: skip it when not requested or when
        # the model was built without an optimizer.
        wants_optimizer = optimizer_state and self.optimizer is not None
        if wants_optimizer:
            optimizer_cpu_state = deep_to(self.optimizer.state_dict(), 'cpu')
            checkpoint['optimizer_state_dict'] = optimizer_cpu_state

        torch.save(checkpoint, file_path)
        self.logger.info(f"Model saved to '{file_path}'")
예제 #2
0
파일: model.py 프로젝트: vfdev-5/argus
    def save(self, file_path: Union[str, Path]):
        """Serialize the argus model to ``file_path`` with ``torch.save``.

        The stored object is a dict of the form::

            {
                'model_name': Name of the argus model class,
                'params': Argus model parameters dict,
                'nn_state_dict': torch nn_module.state_dict()
            }

        The *state_dict* is passed through ``deep_to(..., 'cpu')`` before it
        is written.

        Args:
            file_path (str or :class:`pathlib.Path`): Destination of the argus
                model file.

        """
        module = self.get_nn_module()

        checkpoint = {}
        checkpoint['model_name'] = self.__class__.__name__
        checkpoint['params'] = self.params
        checkpoint['nn_state_dict'] = deep_to(module.state_dict(), 'cpu')

        torch.save(checkpoint, file_path)
        self.logger.info(f"Model saved to '{file_path}'")
예제 #3
0
    def train_step(self, batch, state) -> dict:
        """Run one optimizer step with gradient accumulation over chunks.

        The batch is split with ``deep_chunk(batch, self.iter_size)``. Each
        chunk is forwarded inside ``torch.cuda.amp.autocast`` (enabled by
        ``self.amp``), its loss is divided by ``self.iter_size`` and
        back-propagated. The optimizer is stepped once after all chunks.

        Args:
            batch: Object that ``deep_chunk`` splits into ``(input, target)``
                chunks.
            state: Not used by this implementation.

        Returns:
            dict: ``'prediction'``, ``'target'`` and ``'loss'``. All three
            come from the last processed chunk only; ``'loss'`` is that
            chunk's loss already divided by ``self.iter_size``.

        """
        self.train()
        self.optimizer.zero_grad()

        # Gradient accumulation: gradients of all chunks add up before the
        # single optimizer step below.
        for chunk in deep_chunk(batch, self.iter_size):
            inputs, target = deep_to(chunk, self.device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=self.amp):
                prediction = self.nn_module(inputs)
                loss = self.loss(prediction, target) / self.iter_size

            # Under AMP the backward pass runs on the scaler-scaled loss.
            backward_source = self.scaler.scale(loss) if self.amp else loss
            backward_source.backward()

        if not self.amp:
            self.optimizer.step()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

        detached_prediction = deep_detach(prediction)
        detached_target = deep_detach(target)
        return {
            'prediction': self.prediction_transform(detached_prediction),
            'target': detached_target,
            'loss': loss.item()
        }
예제 #4
0
    def train_step(self, batch, state: State) -> dict:
        """Run one train step, raising custom events around it.

        ``CustomEvents.STEP_START`` is raised with ``state.input_batch`` set
        to ``batch[0]``; ``CustomEvents.STEP_COMPLETE`` is raised with
        ``state.prediction`` set to the transformed prediction, which is
        reset to ``None`` right after the event.

        Args:
            batch: Indexable object; ``batch[0]`` is exposed to the start
                event and the whole batch unpacks into ``(input, target)``
                after ``deep_to``.
            state (State): Engine state used to publish data and raise events.

        Returns:
            dict: ``'prediction'`` (detached and transformed), ``'target'``
            (detached) and ``'loss'`` (``loss.item()``).

        """
        state.input_batch = batch[0]
        state.engine.raise_event(CustomEvents.STEP_START)
        # NOTE(review): this resets ``state.batch`` while the attribute set
        # above is ``state.input_batch`` - confirm that is intended.
        state.batch = None

        self.train()
        self.optimizer.zero_grad()
        inputs, targets = deep_to(batch,
                                  device=self.device,
                                  non_blocking=True)
        outputs = self.nn_module(inputs)
        loss = self.loss(outputs, targets)
        loss.backward()
        self.optimizer.step()

        outputs = deep_detach(outputs)
        targets = deep_detach(targets)
        outputs = self.prediction_transform(outputs)

        # Publish the prediction only for the duration of the event.
        state.prediction = outputs
        state.engine.raise_event(CustomEvents.STEP_COMPLETE)
        state.prediction = None

        return dict(prediction=outputs, target=targets, loss=loss.item())
예제 #5
0
    def train_step(self, batch, state) -> dict:
        """Run one optimizer step with gradient accumulation and EMA update.

        The batch is split with ``deep_chunk(batch, self.iter_size)``; each
        chunk is forwarded and back-propagated (through
        ``self.amp.scale_loss`` when ``self.amp`` is set). After the single
        optimizer step ``torch.cuda.synchronize()`` is called and, when
        ``self.model_ema`` is set, it is updated from ``self.nn_module``.

        Args:
            batch: Object that ``deep_chunk`` splits into ``(input, target)``
                chunks.
            state: Not used by this implementation.

        Returns:
            dict: ``'prediction'``, ``'target'`` and ``'loss'``, all taken
            from the last processed chunk.

        """
        self.train()
        self.optimizer.zero_grad()

        for chunk_idx, chunk in enumerate(deep_chunk(batch, self.iter_size)):
            inputs, target = deep_to(chunk, self.device, non_blocking=True)
            prediction = self.nn_module(inputs)
            loss = self.loss(prediction, target, training=True)

            if self.amp is None:
                loss.backward()
                continue

            # ``delay_unscale`` is requested for every chunk except the one
            # at index ``iter_size - 1``.
            is_last_chunk = chunk_idx == (self.iter_size - 1)
            with self.amp.scale_loss(
                    loss, self.optimizer,
                    delay_unscale=not is_last_chunk) as scaled_loss:
                scaled_loss.backward()

        self.optimizer.step()

        torch.cuda.synchronize()
        if self.model_ema is not None:
            with torch.no_grad():
                self.model_ema.update(self.nn_module)

        detached_prediction = deep_detach(prediction)
        detached_target = deep_detach(target)
        return {
            'prediction': self.prediction_transform(detached_prediction),
            'target': detached_target,
            'loss': loss.item()
        }
예제 #6
0
    def train_step(self, batch, state) -> dict:
        """Run one train step on an ``(input, target, noisy)`` batch.

        When ``self.aux_weights`` is set, the nn_module output is treated as
        a sequence of predictions: the loss is the weighted sum of the
        per-output losses and the first output is reported as the prediction.
        Otherwise the output is used as a single prediction.

        Args:
            batch: Object that unpacks into ``(input, target, noisy)`` after
                ``deep_to``.
            state: Not used by this implementation.

        Returns:
            dict: ``'prediction'`` (detached and transformed), ``'target'``
            (detached), ``'loss'`` (``loss.item()``) and ``'noisy'``.

        """
        if not self.nn_module.training:
            self.nn_module.train()
        self.optimizer.zero_grad()
        input, target, noisy = deep_to(batch, self.device, non_blocking=True)
        prediction = self.nn_module(input)

        has_aux = self.aux_weights is not None
        if has_aux:
            loss = 0
            for pred, weight in zip(prediction, self.aux_weights):
                loss += self.loss(pred, target, noisy) * weight
        else:
            loss = self.loss(prediction, target, noisy)

        if self.use_amp:
            with self.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()

        prediction = deep_detach(prediction)
        target = deep_detach(target)
        # Only pick the first output when the module produced auxiliary
        # outputs; indexing a single prediction would slice the batch.
        # This mirrors the guard used in ``predict``.
        if has_aux:
            prediction = prediction[0]
        return {
            'prediction': self.prediction_transform(prediction),
            'target': target,
            'loss': loss.item(),
            'noisy': noisy
        }
예제 #7
0
파일: model.py 프로젝트: phamcuong92/argus
 def predict(self, input):
     """Return the transformed nn_module output for ``input``.

     Asserts ``self.predict_ready()``, then, with gradient tracking
     disabled, switches the model to eval mode, passes ``input`` through
     ``deep_to`` with the model device, runs the nn_module and applies the
     prediction_transform.

     Args:
         input: Data to feed to the nn_module.

     Returns:
         The result of ``self.prediction_transform`` on the raw output.
     """
     assert self.predict_ready()
     with torch.no_grad():
         self.eval()
         on_device = deep_to(input, self.device)
         raw_output = self.nn_module(on_device)
         return self.prediction_transform(raw_output)
예제 #8
0
파일: model.py 프로젝트: phamcuong92/argus
 def save(self, file_path):
     """Write the model name, params and nn_module weights to ``file_path``.

     The saved dict has the keys ``'model_name'``, ``'params'`` and
     ``'nn_state_dict'``; the state dict is passed through
     ``deep_to(..., 'cpu')`` before ``torch.save``.

     Args:
         file_path: Destination passed to ``torch.save``.
     """
     module = self.get_nn_module()
     checkpoint = dict(
         model_name=self.__class__.__name__,
         params=self.params,
         nn_state_dict=deep_to(module.state_dict(), 'cpu'),
     )
     torch.save(checkpoint, file_path)
     self.logger.info(f"Model saved to '{file_path}'")
예제 #9
0
파일: ema.py 프로젝트: lRomul/argus-alaska
    def save(self, file_path, argus_state):
        """Save the EMA weights together with the regular model weights.

        The module at ``argus_state.model.model_ema.ema`` is stored under
        ``'nn_state_dict'`` and the module returned by
        ``argus_state.model.get_nn_module()`` under
        ``'no_ema_nn_state_dict'``. Both state dicts are passed through
        ``deep_to(..., 'cpu')`` before ``torch.save``.

        Args:
            file_path: Destination passed to ``torch.save``.
            argus_state: Object exposing ``model`` and ``logger``.
        """
        model = argus_state.model

        # Unwrap parallel containers so the saved keys belong to the
        # underlying module.
        ema_module = model.model_ema.ema
        if isinstance(ema_module, (DataParallel, DistributedDataParallel)):
            ema_module = ema_module.module

        plain_module = model.get_nn_module()
        if isinstance(plain_module, DistributedDataParallel):
            plain_module = plain_module.module

        checkpoint = dict(
            model_name=model.__class__.__name__,
            params=model.params,
            nn_state_dict=deep_to(ema_module.state_dict(), 'cpu'),
            no_ema_nn_state_dict=deep_to(plain_module.state_dict(), 'cpu'),
        )
        torch.save(checkpoint, file_path)
        argus_state.logger.info(f"Model saved to '{file_path}'")
예제 #10
0
 def predict(self, input):
     """Return the transformed nn_module output for ``input``.

     Asserts ``self.predict_ready()``. With gradient tracking disabled the
     nn_module is put into eval mode if it is training, ``input`` is passed
     through ``deep_to`` with the model device and forwarded. When
     ``self.aux_weights`` is set only the first output is kept before the
     prediction_transform is applied.

     Args:
         input: Data to feed to the nn_module.

     Returns:
         The result of ``self.prediction_transform``.
     """
     assert self.predict_ready()
     with torch.no_grad():
         if self.nn_module.training:
             self.nn_module.eval()
         on_device = deep_to(input, self.device)
         output = self.nn_module(on_device)
         # With auxiliary outputs, only the first one is the prediction.
         main_output = output if self.aux_weights is None else output[0]
         return self.prediction_transform(main_output)
예제 #11
0
 def val_step(self, batch, state) -> dict:
     """Evaluate one validation batch, using the EMA module when present.

     The model is switched to eval mode. With gradient tracking disabled
     the batch is passed through ``deep_to`` with the model device and
     forwarded through ``self.model_ema.ema`` if ``self.model_ema`` is set,
     otherwise through ``self.nn_module``.

     Args:
         batch: Object that unpacks into ``(input, target)`` after
             ``deep_to``.
         state: Not used by this implementation.

     Returns:
         dict: ``'prediction'`` (transformed), ``'target'`` and ``'loss'``
         (``loss.item()``).
     """
     self.eval()
     with torch.no_grad():
         inputs, target = deep_to(batch, self.device, non_blocking=True)
         if self.model_ema is not None:
             raw_prediction = self.model_ema.ema(inputs)
         else:
             raw_prediction = self.nn_module(inputs)
         loss = self.loss(raw_prediction, target)
         return dict(
             prediction=self.prediction_transform(raw_prediction),
             target=target,
             loss=loss.item(),
         )
예제 #12
0
 def val_step(self, batch, state) -> dict:
     """Evaluate one ``(input, target, noisy)`` validation batch.

     When ``self.aux_weights`` is set, the nn_module output is treated as a
     sequence of predictions: the loss is the weighted sum of the
     per-output losses and the first output is reported as the prediction.
     Otherwise the output is used as a single prediction.

     Args:
         batch: Object that unpacks into ``(input, target, noisy)`` after
             ``deep_to``.
         state: Not used by this implementation.

     Returns:
         dict: ``'prediction'`` (transformed), ``'target'``, ``'loss'``
         (``loss.item()``) and ``'noisy'``.
     """
     if self.nn_module.training:
         self.nn_module.eval()
     with torch.no_grad():
         input, target, noisy = deep_to(batch, self.device, non_blocking=True)
         prediction = self.nn_module(input)
         if self.aux_weights is not None:
             loss = 0
             for pred, weight in zip(prediction, self.aux_weights):
                 loss += self.loss(pred, target, noisy) * weight
             # Only the first of the auxiliary outputs is the prediction.
             prediction = prediction[0]
         else:
             # Single output: do not index it, that would slice the batch.
             # This mirrors the guard used in ``predict``.
             loss = self.loss(prediction, target, noisy)
         return {
             'prediction': self.prediction_transform(prediction),
             'target': target,
             'loss': loss.item(),
             'noisy': noisy
         }
예제 #13
0
def test_deep_to(list_of_tensors, dict_of_tensors, destination_dtype):
    """Check ``deep_to`` dtype conversion on containers, modules and scalars.

    Covers a list of tensors, a dict of tensors (keys must stay strings),
    a ``torch.nn.Linear`` module, and values that are expected to come back
    as they were (a string, ``None`` and ``True``).
    """
    converted_list = deep_to(list_of_tensors, dtype=destination_dtype)
    assert all(item.dtype == destination_dtype for item in converted_list)

    converted_dict = deep_to(dict_of_tensors, dtype=destination_dtype)
    assert all(isinstance(name, str) for name in converted_dict.keys())
    assert all(item.dtype == destination_dtype
               for item in converted_dict.values())

    linear = torch.nn.Linear(128, 8)
    converted_linear = deep_to(linear, dtype=destination_dtype)
    assert converted_linear.weight.dtype == destination_dtype

    # Values that are not tensors or modules.
    assert deep_to('qwerty', dtype=destination_dtype) == 'qwerty'
    assert deep_to(None, dtype=destination_dtype) is None
    assert deep_to(True, dtype=destination_dtype)
예제 #14
0
파일: model.py 프로젝트: lRomul/argus
    def val_step(self, batch, state: State) -> dict:
        """Evaluate the model on a single validation batch.

        The method is used by :class:`argus.engine.Engine`.
        The model is switched to eval mode and, with gradient tracking
        disabled, the batch is passed through ``deep_to`` with the model
        device, forwarded through the nn_module, scored with the loss, and
        the prediction is processed by the prediction_transform.

        Unlike :meth:`train_step`, no backward pass and no optimizer update
        are performed.

        Args:
            batch (tuple of 2 torch.Tensors: (input, target)): The input data
                and target tensors to process.
            state (:class:`argus.engine.State`): The argus model state. It is
                not used by this implementation.

        Returns:
            dict: Default val step results::

                {
                    'prediction': The val batch predictions,
                    'target': The val batch target data on the model device,
                    'loss': The loss function value
                }

        """
        self.eval()
        with torch.no_grad():
            inputs, target = deep_to(
                batch, device=self.device, non_blocking=True)
            raw_prediction = self.nn_module(inputs)
            loss = self.loss(raw_prediction, target)
            return dict(
                prediction=self.prediction_transform(raw_prediction),
                target=target,
                loss=loss.item(),
            )
예제 #15
0
파일: model.py 프로젝트: lRomul/argus
    def train_step(self, batch, state: State) -> dict:
        """Run a single optimisation step on one train batch.

        The method is used by :class:`argus.engine.Engine`.
        The model is switched to train mode, gradients are zeroed, the batch
        is passed through ``deep_to`` with the model device, forwarded
        through the nn_module and scored with the loss. After the backward
        pass and the optimizer step, the prediction and target are detached
        and the prediction is processed by the prediction_transform.

        Args:
            batch (tuple of 2 torch.Tensors: (input, target)): The input and
                target tensors to process.
            state (:class:`argus.engine.State`): The argus model state. It is
                not used by this implementation.

        Returns:
            dict: The train step results::

                {
                    'prediction': The train batch predictions,
                    'target': The train batch target data on the model device,
                    'loss': The loss function value
                }

        """
        self.train()
        self.optimizer.zero_grad()

        inputs, target = deep_to(
            batch, device=self.device, non_blocking=True)
        raw_prediction = self.nn_module(inputs)
        loss = self.loss(raw_prediction, target)
        loss.backward()
        self.optimizer.step()

        detached_prediction = deep_detach(raw_prediction)
        detached_target = deep_detach(target)
        return dict(
            prediction=self.prediction_transform(detached_prediction),
            target=detached_target,
            loss=loss.item(),
        )
예제 #16
0
파일: model.py 프로젝트: phamcuong92/argus
def load_model(file_path, device=None):
    """Load an argus model from a file written by the model's ``save``.

    The model class is looked up in ``MODEL_REGISTRY`` by the stored
    ``'model_name'``, built from the stored ``'params'`` and its nn_module
    receives the stored ``'nn_state_dict'``. The model is returned in eval
    mode.

    Args:
        file_path: Path to the saved state.
        device: Optional device; when given it replaces ``params['device']``
            after ``cast_device`` / ``device_to_str`` conversion.

    Returns:
        The loaded model instance.

    Raises:
        FileNotFoundError: If ``file_path`` is not an existing file.
        ImportError: If the stored model name is not in ``MODEL_REGISTRY``.
    """
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"No state found at {file_path}")

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(file_path)

    if state['model_name'] not in MODEL_REGISTRY:
        raise ImportError(
            f"Model '{state['model_name']}' not found in scope")

    params = state['params']
    if device is not None:
        # Override the device the model was saved with.
        params['device'] = device_to_str(cast_device(device))

    model = MODEL_REGISTRY[state['model_name']](params)
    weights = deep_to(state['nn_state_dict'], model.device)
    model.get_nn_module().load_state_dict(weights)
    model.eval()
    return model
예제 #17
0
파일: model.py 프로젝트: lRomul/argus
    def predict(self, input):
        """Make a prediction with the given input.

        After ``self._check_predict_ready()``, and with gradient tracking
        disabled, the model is switched to *eval* mode, the input is passed
        through ``deep_to`` with the model device, forwarded through the
        nn_module, and the raw output is processed by the
        prediction_transform.

        Args:
            input (torch.Tensor): The input tensor to predict with. It will be
                transferred to the model device. The user is responsible for
                ensuring that the input tensor shape and type match the model.

        Returns:
            torch.Tensor or other type: Predictions as the result of the
            prediction_transform application.

        """
        self._check_predict_ready()
        with torch.no_grad():
            self.eval()
            on_device = deep_to(input, self.device)
            raw_output = self.nn_module(on_device)
            return self.prediction_transform(raw_output)
예제 #18
0
파일: model.py 프로젝트: vfdev-5/argus
def load_model(file_path: Union[str, Path],
               nn_module=default,
               optimizer=default,
               loss=default,
               prediction_transform=default,
               device=default,
               change_params_func=identity,
               change_state_dict_func=identity,
               model_name=default,
               **kwargs):
    """Load an argus model from a file.

    The function allows loading an argus model, saved with
    :meth:`argus.model.Model.save`. The model is always loaded in *eval* mode.

    Arguments left at the ``default`` sentinel keep the value stored in the
    file; any other value replaces the corresponding entry of the stored
    params.

    Args:
        file_path (str or :class:`pathlib.Path`): Path to the file to load.
        nn_module (dict, tuple or str, optional): Params of the nn_module to
            replace params in the state. Must not be `None`.
        optimizer (dict, tuple or str, optional): Params of the optimizer to
            replace params in the state. Set to `None` if don't want to create
            optimizer in the loaded model.
        loss (dict, tuple or str, optional): Params of the loss to replace
            params in the state. Set to `None` if don't want to create loss in
            the loaded model.
        prediction_transform (dict, tuple or str, optional): Params of the
            prediction_transform to replace params in the state. Set to `None`
            if don't want to create prediction_transform in the loaded model.
        device (str or :class:`torch.device`, optional): Device for the model.
            By default the device stored in the params is kept.
        change_params_func (function, optional): Function for modification of
            state params. Takes as input params from the loaded state, outputs
            params to model creation.
        change_state_dict_func (function, optional): Function for modification
            of nn_module state dict. Takes as input state dict from the loaded
            state, outputs state dict to model creation.
        model_name (str): Class name of :class:`argus.model.Model`.
            By default uses name from loaded state.
        **kwargs: Extra ``attribute=params`` pairs written into the params.

    Raises:
        ImportError: If the model is not available in the scope. Often it means
            that it is not imported or defined.
        FileNotFoundError: If the file is not found by the *file_path*.
        ValueError: If *nn_module* is explicitly set to `None`.

    Returns:
        :class:`argus.model.Model`: Loaded argus model.

    """
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"No state found at {file_path}")

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(file_path)

    if model_name is default:
        model_name = state['model_name']

    if model_name not in MODEL_REGISTRY:
        raise ImportError(f"Model '{model_name}' not found in scope")

    params = state['params']
    if device is not default:
        params['device'] = device_to_str(cast_device(device))

    if nn_module is not default:
        if nn_module is None:
            raise ValueError("nn_module is required attribute for argus.Model")
        params['nn_module'] = nn_module

    # These three may be overridden with None to skip their creation.
    optional_overrides = (
        ('optimizer', optimizer),
        ('loss', loss),
        ('prediction_transform', prediction_transform),
    )
    for param_key, override in optional_overrides:
        if override is not default:
            params[param_key] = override

    for extra_key, extra_params in kwargs.items():
        params[extra_key] = extra_params

    model_class = MODEL_REGISTRY[model_name]
    model = model_class(change_params_func(params))

    weights = deep_to(state['nn_state_dict'], model.device)
    weights = change_state_dict_func(weights)
    model.get_nn_module().load_state_dict(weights)

    model.eval()
    return model
예제 #19
0
 def prepare_batch(self, batch, device):
     """Pass each element of an ``(input, target, noisy)`` batch to ``deep_to``.

     Args:
         batch: Iterable that unpacks into exactly three items.
         device: Device argument forwarded to ``deep_to``.

     Returns:
         tuple: ``(input, target, noisy)`` as returned by ``deep_to`` with
         ``non_blocking=True``.
     """
     input, target, noisy = batch
     return tuple(
         deep_to(item, device, non_blocking=True)
         for item in (input, target, noisy)
     )
예제 #20
0
def load_model(
        file_path: types.Path,
        nn_module: Union[Default, types.Param] = default,
        optimizer: Union[Default, None, types.Param] = default,
        loss: Union[Default, None, types.Param] = default,
        prediction_transform: Union[Default, None, types.Param] = default,
        device: Union[Default, types.InputDevices] = default,
        change_params_func: Callable = identity,
        change_state_dict_func: Callable = default_change_state_dict_func,
        model_name: Union[Default, str] = default,
        **kwargs):
    """Load an argus model from a file.

    The function allows loading an argus model, saved with
    :meth:`argus.model.Model.save`. The model is always loaded in *eval* mode.

    Args:
        file_path (str or :class:`pathlib.Path`): Path to the file to load.
        device (str, torch.device or list of devices, optional): Device for
            the model.
        nn_module (dict, tuple or str, optional): Params of the nn_module to
            replace params in the state.
        optimizer (None, dict, tuple or str, optional): Params of the
            optimizer to replace params in the state. Optimizer is not created
            in the loaded model if it is set to `None`.
        loss (None, dict, tuple or str, optional): Params of the loss to
            replace params in the state. Loss is not created in the loaded
            model if it is set to `None`.
        prediction_transform (None, dict, tuple or str, optional): Params of
            the prediction_transform to replace params in the state.
            prediction_transform is not created in the loaded model if it is
            set to `None`.
        change_params_func (function, optional): Function for modification of
            the loaded params. It takes params from the loaded state as an
            input and outputs params to use during the model creation.
        change_state_dict_func (function, optional): Function for modification
            of nn_module and optimizer state dict. Takes `nn_state_dict` and
            `optimizer_state_dict` as inputs and outputs state dicts for the
            model creation.
        model_name (str, optional): Class name of :class:`argus.model.Model`.
            By default uses the name from the loaded state.
        **kwargs: Extra ``attribute=params`` pairs written into the params.

    Returns:
        :class:`argus.model.Model`: Loaded argus model.

    Example:

        .. code-block:: python

            model = ArgusModel(params)
            model.save(model_path, optimizer_state=True)

            # restarting python...

            # ArgusModel class must be in scope at this moment
            model = argus.load_model(model_path, device="cuda:0")

        More options how to use load_model you can find
        `here <https://github.com/lRomul/argus/blob/master/examples/load_model.py>`_.

    Raises:
        ImportError: If the model is not available in the scope. Often it means
            that it is not imported or defined.
        FileNotFoundError: If the file is not found by the *file_path*.
        ValueError: If *nn_module* is explicitly set to `None`.

    """
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"No state found at {file_path}")

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(file_path)

    if isinstance(model_name, Default):
        str_model_name = state['model_name']
    else:
        str_model_name = model_name

    if str_model_name not in MODEL_REGISTRY:
        # Report the resolved name: ``model_name`` may still be the
        # ``default`` sentinel here, which would hide the real class name.
        raise ImportError(f"Model '{str_model_name}' not found in scope")

    params = state['params']
    if not isinstance(device, Default):
        params['device'] = device_to_str(cast_device(device))

    if nn_module is not default:
        if nn_module is None:
            raise ValueError(
                "nn_module is required attribute for argus.Model")
        params['nn_module'] = nn_module
    if optimizer is not default:
        params['optimizer'] = optimizer
    if loss is not default:
        params['loss'] = loss
    if prediction_transform is not default:
        params['prediction_transform'] = prediction_transform

    for attribute, attribute_params in kwargs.items():
        params[attribute] = attribute_params

    model_class = MODEL_REGISTRY[str_model_name]
    params = change_params_func(params)
    model = model_class(params)

    nn_state_dict = deep_to(state['nn_state_dict'], model.device)
    # The optimizer state is only present when it was saved with the model.
    optimizer_state_dict = None
    if 'optimizer_state_dict' in state:
        optimizer_state_dict = deep_to(state['optimizer_state_dict'],
                                       model.device)
    nn_state_dict, optimizer_state_dict = change_state_dict_func(
        nn_state_dict, optimizer_state_dict)

    model.get_nn_module().load_state_dict(nn_state_dict)
    if model.optimizer is not None and optimizer_state_dict is not None:
        model.optimizer.load_state_dict(optimizer_state_dict)
    model.eval()
    return model
 def prepare_batch(self, batch, device):
     """Pass the batch images to ``deep_to`` and return the texts unchanged.

     Args:
         batch: Mapping with the keys ``"image"`` and ``"text"``.
         device: Device argument forwarded to ``deep_to``.

     Returns:
         tuple: ``(images, texts)`` where ``images`` is the result of
         ``deep_to`` with ``non_blocking=True`` and ``texts`` is
         ``batch["text"]`` as is.
     """
     image_batch = batch["image"]
     text_batch = batch["text"]
     moved_images = deep_to(image_batch, device, non_blocking=True)
     return moved_images, text_batch