Example No. 1
    def set_device(self, device: types.InputDevices):
        """Move nn_module and loss to the specified device.

        If a list of devices is passed, :class:`torch.nn.DataParallel` will be
        used. Batch tensors will be scattered on dim 0. The first device in the
        list is the location of the output. By default, device "cuda" refers to
        the GPU returned by :func:`torch.cuda.current_device`.

        Example:

            .. code-block:: python

                model.set_device("cuda")
                model.set_device(torch.device("cuda"))

                model.set_device("cuda:0")
                model.set_device(["cuda:2", "cuda:3"])  # Use DataParallel

                model.set_device([torch.device("cuda:2"),
                                  torch.device("cuda", index=3)])

        Args:
            device (str, torch.device or list of devices): A device or list of
                devices.

        """
        torch_device = cast_device(device)
        nn_module = self.get_nn_module()

        if isinstance(torch_device, (list, tuple)):
            device_ids = get_device_indices(torch_device)
            nn_module = DataParallel(nn_module, device_ids=device_ids)
            output_device = torch_device[0]
        else:
            output_device = torch_device

        self.nn_module = nn_module.to(output_device)
        if self.loss is not None:
            self.loss = self.loss.to(output_device)
        self.params['device'] = device_to_str(torch_device)
        self.device = output_device
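
The helpers used above (cast_device, get_device_indices, device_to_str) are not shown in this example; below is a minimal sketch of the behavior they are assumed to have, using illustrative `_sketch` names rather than the library's actual implementations.

from typing import List, Union

import torch


def cast_device_sketch(device) -> Union[torch.device, List[torch.device]]:
    # Normalize a string, torch.device, or list/tuple of either to torch.device.
    if isinstance(device, (list, tuple)):
        return [torch.device(d) for d in device]
    return torch.device(device)


def get_device_indices_sketch(devices: List[torch.device]) -> List[int]:
    # DataParallel expects plain integer GPU indices, e.g. [2, 3].
    return [device.index for device in devices]


def device_to_str_sketch(device) -> Union[str, List[str]]:
    # Inverse of cast_device_sketch: devices are stored in params as strings.
    if isinstance(device, (list, tuple)):
        return [str(d) for d in device]
    return str(device)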
Example No. 2
def load_model(file_path, device=None):
    if os.path.isfile(file_path):
        state = torch.load(file_path)

        if state['model_name'] in MODEL_REGISTRY:
            params = state['params']
            if device is not None:
                device = cast_device(device)
                device = device_to_str(device)
                params['device'] = device

            model_class = MODEL_REGISTRY[state['model_name']]
            model = model_class(params)
            nn_state_dict = deep_to(state['nn_state_dict'], model.device)

            model.get_nn_module().load_state_dict(nn_state_dict)
            model.eval()
            return model
        else:
            raise ImportError(
                f"Model '{state['model_name']}' not found in scope")
    else:
        raise FileNotFoundError(f"No state found at {file_path}")
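
A usage sketch for this loader; the checkpoint path is illustrative.

# Load a saved checkpoint and override the device stored in its params.
# The returned model is already switched to eval mode.
model = load_model("checkpoints/best-model.pth", device="cuda:0")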
Example No. 3
    def set_device(self, device):
        device = cast_device(device)
        str_device = device_to_str(device)
        nn_module = self.get_nn_module()

        if isinstance(device, (list, tuple)):
            device_ids = []
            for dev in device:
                if dev.type != 'cuda':
                    raise ValueError("Non-CUDA device in the list of devices")
                if dev.index is None:
                    raise ValueError("CUDA device in the list must have an index")
                device_ids.append(dev.index)
            if len(device_ids) != len(set(device_ids)):
                raise ValueError("Cuda device indices must be unique")
            nn_module = DataParallel(nn_module, device_ids=device_ids)
            device = device[0]

        self.params['device'] = str_device
        self.device = device
        self.nn_module = nn_module.to(self.device)
        if self.loss is not default:
            self.loss = self.loss.to(self.device)
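
A sketch of how the validation above behaves for different device lists, assuming `model` is an instance of the class that defines this method.

model.set_device(["cuda:0", "cuda:1"])  # OK: wraps nn_module in DataParallel
model.set_device(["cpu", "cuda:0"])     # ValueError: non-CUDA device in the list
model.set_device(["cuda", "cuda:1"])    # ValueError: "cuda" has no explicit index
model.set_device(["cuda:0", "cuda:0"])  # ValueError: duplicate device indices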
Example No. 4
def load_model(
        file_path: types.Path,
        nn_module: Union[Default, types.Param] = default,
        optimizer: Union[Default, None, types.Param] = default,
        loss: Union[Default, None, types.Param] = default,
        prediction_transform: Union[Default, None, types.Param] = default,
        device: Union[Default, types.InputDevices] = default,
        change_params_func: Callable = identity,
        change_state_dict_func: Callable = default_change_state_dict_func,
        model_name: Union[Default, str] = default,
        **kwargs):
    """Load an argus model from a file.

    The function allows loading an argus model, saved with
    :meth:`argus.model.Model.save`. The model is always loaded in *eval* mode.

    Args:
        file_path (str or :class:`pathlib.Path`): Path to the file to load.
        device (str, torch.device or list of devices, optional): Device for the model.
        nn_module (dict, tuple or str, optional): Params of the nn_module to
            replace params in the state.
        optimizer (None, dict, tuple or str, optional): Params of the optimizer to
            replace params in the state. Optimizer is not created in the loaded
            model if it is set to `None`.
        loss (None, dict, tuple or str, optional): Params of the loss to replace params
            in the state. Loss is not created in the loaded model if it is set
            to `None`.
        prediction_transform (None, dict, tuple or str, optional): Params of the
            prediction_transform to replace params in the state.
            prediction_transform is not created in the loaded model if it is
            set to `None`.
        change_params_func (function, optional): Function for modification of
            the loaded params. It takes the params from the loaded state as
            input and outputs the params to use during model creation.
        change_state_dict_func (function, optional): Function for modification of
            the nn_module and optimizer state dicts. Takes `nn_state_dict` and
            `optimizer_state_dict` as inputs and outputs the state dicts used
            for model creation.
        model_name (str, optional): Class name of :class:`argus.model.Model`.
            By default uses the name from the loaded state.

    Returns:
        :class:`argus.model.Model`: Loaded argus model.

    Example:

        .. code-block:: python

            model = ArgusModel(params)
            model.save(model_path, optimizer_state=True)

            # restarting python...

            # ArgusModel class must be in scope at this moment
            model = argus.load_model(model_path, device="cuda:0")

        More options for using load_model can be found
        `here <https://github.com/lRomul/argus/blob/master/examples/load_model.py>`_.

    Raises:
        ImportError: If the model is not available in the scope. Often it means
            that it is not imported or defined.
        FileNotFoundError: If the file is not found by the *file_path*.

    """

    if os.path.isfile(file_path):
        state = torch.load(file_path)

        if isinstance(model_name, Default):
            str_model_name = state['model_name']
        else:
            str_model_name = model_name

        if str_model_name in MODEL_REGISTRY:
            params = state['params']
            if not isinstance(device, Default):
                params['device'] = device_to_str(cast_device(device))

            if nn_module is not default:
                if nn_module is None:
                    raise ValueError(
                        "nn_module is required attribute for argus.Model")
                params['nn_module'] = nn_module
            if optimizer is not default:
                params['optimizer'] = optimizer
            if loss is not default:
                params['loss'] = loss
            if prediction_transform is not default:
                params['prediction_transform'] = prediction_transform

            for attribute, attribute_params in kwargs.items():
                params[attribute] = attribute_params

            model_class = MODEL_REGISTRY[str_model_name]
            params = change_params_func(params)
            model = model_class(params)
            nn_state_dict = deep_to(state['nn_state_dict'], model.device)
            optimizer_state_dict = None
            if 'optimizer_state_dict' in state:
                optimizer_state_dict = deep_to(state['optimizer_state_dict'],
                                               model.device)
            nn_state_dict, optimizer_state_dict = change_state_dict_func(
                nn_state_dict, optimizer_state_dict)

            model.get_nn_module().load_state_dict(nn_state_dict)
            if model.optimizer is not None and optimizer_state_dict is not None:
                model.optimizer.load_state_dict(optimizer_state_dict)
            model.eval()
            return model
        else:
            raise ImportError(f"Model '{model_name}' not found in scope")
    else:
        raise FileNotFoundError(f"No state found at {file_path}")
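
A sketch of the hook arguments described in the docstring, continuing the `model_path` example from above; the prefix-stripping logic is illustrative and not part of argus.

def strip_module_prefix(nn_state_dict, optimizer_state_dict):
    # Example hook: drop a leading "module." prefix (e.g. left by DataParallel).
    nn_state_dict = {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in nn_state_dict.items()
    }
    return nn_state_dict, optimizer_state_dict


model = argus.load_model(model_path,
                         optimizer=None,  # skip optimizer creation
                         change_state_dict_func=strip_module_prefix)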
Example No. 5
def load_model(file_path: Union[str, Path],
               nn_module=default,
               optimizer=default,
               loss=default,
               prediction_transform=default,
               device=default,
               change_params_func=identity,
               change_state_dict_func=identity,
               model_name=default,
               **kwargs):
    """Load an argus model from a file.

    The function allows loading an argus model, saved with
    :meth:`argus.model.Model.save`. The model is always loaded in *eval* mode.

    Args:
        file_path (str or :class:`pathlib.Path`): Path to the file to load.
        device (str or :class:`torch.device`, optional): Device for the model.
            By default, the device from the loaded state is used.
        nn_module (dict, tuple or str, optional): Params of the nn_module to
            replace params in the state.
        optimizer (dict, tuple or str, optional): Params of the optimizer to
            replace params in the state. Set to `None` if you don't want to
            create an optimizer in the loaded model.
        loss (dict, tuple or str, optional): Params of the loss to replace params
            in the state. Set to `None` if you don't want to create a loss in
            the loaded model.
        prediction_transform (dict, tuple or str, optional): Params of the
            prediction_transform to replace params in the state. Set to `None`
            if you don't want to create a prediction_transform in the loaded
            model.
        change_params_func (function, optional): Function for modification of
            state params. Takes the params from the loaded state as input and
            outputs the params used for model creation.
        change_state_dict_func (function, optional): Function for modification of
            the nn_module state dict. Takes the state dict from the loaded
            state as input and outputs the state dict used for model creation.
        model_name (str, optional): Class name of :class:`argus.model.Model`.
            By default, uses the name from the loaded state.

    Raises:
        ImportError: If the model is not available in the scope. Often it means
            that it is not imported or defined.
        FileNotFoundError: If the file is not found by the *file_path*.

    Returns:
        :class:`argus.model.Model`: Loaded argus model.

    """

    if os.path.isfile(file_path):
        state = torch.load(file_path)

        if model_name is default:
            model_name = state['model_name']

        if model_name in MODEL_REGISTRY:
            params = state['params']
            if device is not default:
                device = cast_device(device)
                device = device_to_str(device)
                params['device'] = device

            if nn_module is not default:
                if nn_module is None:
                    raise ValueError("nn_module is required attribute for argus.Model")
                params['nn_module'] = nn_module
            if optimizer is not default:
                params['optimizer'] = optimizer
            if loss is not default:
                params['loss'] = loss
            if prediction_transform is not default:
                params['prediction_transform'] = prediction_transform

            for attribute, attribute_params in kwargs.items():
                params[attribute] = attribute_params

            model_class = MODEL_REGISTRY[model_name]
            params = change_params_func(params)
            model = model_class(params)
            nn_state_dict = deep_to(state['nn_state_dict'], model.device)
            nn_state_dict = change_state_dict_func(nn_state_dict)

            model.get_nn_module().load_state_dict(nn_state_dict)
            model.eval()
            return model
        else:
            raise ImportError(f"Model '{model_name}' not found in scope")
    else:
        raise FileNotFoundError(f"No state found at {file_path}")
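
A usage sketch for this version of the loader; the file name and the (name, params) optimizer tuple are illustrative.

# Recreate the model on a specific device with new optimizer params.
model = load_model("model-last.pth",
                   device="cuda:0",
                   optimizer=("Adam", {"lr": 1e-4}))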
Example No. 6
def test_device_to_str():
    assert 'cpu' == device_to_str(torch.device('cpu'))
    devices = (torch.device('cuda:0'), torch.device('cuda:1'))
    assert ['cuda:0', 'cuda:1'] == device_to_str(devices)
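
A companion test sketch for cast_device, assuming it mirrors device_to_str and is importable from the same module.

def test_cast_device():
    assert torch.device('cpu') == cast_device('cpu')
    devices = ('cuda:0', 'cuda:1')
    assert [torch.device('cuda:0'), torch.device('cuda:1')] == cast_device(devices)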