Code example #1
from unittest.mock import Mock

from pytorch_lightning.lite.wrappers import _LiteModule


def test_lite_module_wraps():
    """Test that the wrapped module is accessible via the property."""
    module = Mock()
    assert _LiteModule(module, Mock()).module is module

    wrapped_module = Mock()
    original_module = Mock()
    lite_module = _LiteModule(wrapped_module, Mock(), original_module=original_module)
    assert lite_module.module is original_module
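In other words, the module property is expected to prefer the original, user-defined module over the strategy-wrapped one, and to fall back to the wrapped module when no original was given. A minimal sketch of such a property, assuming the constructor stores its arguments as _forward_module and _original_module (the names are illustrative, not necessarily the library's):

import torch


class _WrapperSketch(torch.nn.Module):
    """Illustrative stand-in for the wrapper under test; not the library implementation."""

    def __init__(self, forward_module, precision_plugin, original_module=None):
        super().__init__()
        self._forward_module = forward_module
        self._original_module = original_module
        self._precision_plugin = precision_plugin

    @property
    def module(self):
        # Prefer the unwrapped module the user passed in, if any.
        return self._original_module or self._forward_module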
Code example #2
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning.lite.wrappers import _LiteModule


def test_lite_module_attribute_lookup():
    """Test that attribute lookup passes through to the original module when possible."""

    class OriginalModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 3)
            self.attribute = 1

        def method(self):
            return 2

    original_module = OriginalModule()

    class ModuleWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.wrapped = original_module

    wrapped_module = ModuleWrapper()

    lite_module = _LiteModule(wrapped_module, Mock(), original_module=original_module)
    # Plain attributes, submodules, and methods all resolve on the original module.
    assert lite_module.attribute == 1
    assert lite_module.layer is original_module.layer
    assert lite_module.method() == 2
    # forward() stays bound to the wrapper, so precision handling is not bypassed.
    assert lite_module.forward.__self__.__class__ == _LiteModule

    # Attributes found on neither the wrapper nor the original module still fail.
    with pytest.raises(AttributeError):
        _ = lite_module.not_exists
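The test pins down a lookup order: attributes defined on the wrapper itself (including forward) win, then attributes of the original module, and only then does the lookup raise AttributeError. A self-contained sketch of that delegation via __getattr__, extending the idea of the sketch above (again illustrative, not the library's implementation):

import torch


class _DelegatingWrapperSketch(torch.nn.Module):
    """Illustrative only; shows the fall-through lookup the test exercises."""

    def __init__(self, forward_module, original_module=None):
        super().__init__()
        self._forward_module = forward_module
        self._original_module = original_module or forward_module

    def __getattr__(self, item):
        try:
            # nn.Module.__getattr__ resolves registered parameters, buffers, and submodules.
            return super().__getattr__(item)
        except AttributeError:
            # Fall back to the user's original module; AttributeError propagates
            # if the attribute is missing there too.
            original = super().__getattr__("_original_module")
            return getattr(original, item)

Because forward is defined directly on the wrapper class, __getattr__ never fires for it, which is exactly what the lite_module.forward.__self__ assertion verifies.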
Code example #3
    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
    ) -> Any:  # no specific return type because the way we want our API to look does not play well with mypy
        """Set up a model and its optimizers for accelerated training.

        Args:
            model: A model to set up
            *optimizers: The optimizer(s) to set up (passing no optimizers is also possible)
            move_to_device: If set to ``True`` (default), moves the model to the correct device. Set this to
                ``False`` and use :meth:`to_device` manually instead.

        Returns:
            The wrapped model and the wrapped optimizers, in the same order they were passed in.
        """
        self._validate_setup(model, optimizers)

        if move_to_device:
            model = self._move_model_to_device(model=model, optimizers=list(optimizers))

        # Let accelerator/plugin wrap and connect the models and optimizers
        model, optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
        model = _LiteModule(model, self._precision_plugin)
        optimizers = [_LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]
        self._models_setup += 1
        if optimizers:
            # join both types in a list for API convenience
            return [model] + optimizers  # type: ignore
        return model
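Because the return shape depends on whether optimizers were passed, call sites unpack the result positionally. A usage sketch inside a LightningLite subclass (the class name, model, and hyperparameters are made up for illustration):

import torch
from pytorch_lightning.lite import LightningLite


class TrainerLite(LightningLite):
    def run(self):
        model = torch.nn.Linear(32, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        # One model and one optimizer in, wrapped equivalents out, same order.
        model, optimizer = self.setup(model, optimizer)

        # With no optimizers, setup returns just the wrapped model.
        eval_model = self.setup(torch.nn.Linear(32, 2))


TrainerLite(accelerator="cpu").run()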
Code example #4
from unittest.mock import Mock

import torch
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteModule


class EmptyLite(LightningLite):
    def run(self):
        pass  # test helper: a no-op LightningLite subclass


# Parametrized over (precision, input_type, expected_type); decorator omitted here. Requires CUDA.
def test_lite_module_forward_conversion(precision, input_type, expected_type):
    """Test that the LiteModule performs autocasting on the input tensors and during forward()."""
    lite = EmptyLite(precision=precision, accelerator="gpu", devices=1)
    device = torch.device("cuda", 0)

    def check_autocast(forward_input):
        # With 16-bit precision, forward() must run under autocast.
        assert precision != 16 or torch.is_autocast_enabled()
        return forward_input

    module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
    lite_module = _LiteModule(module, lite._precision_plugin).to(device)
    out = lite_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
    # Inputs are cast before reaching the wrapped module; outputs are cast back.
    assert module.call_args[0][0].dtype == expected_type
    assert out.dtype == input_type or out.dtype == torch.get_default_dtype()
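The excerpt drops the test's @pytest.mark.parametrize decorator. A hedged reconstruction of what such a parametrization could look like (these two cases are illustrative, not the original table): 16-bit precision should downcast float32 inputs before forward(), while full precision should upcast half inputs back to float32.

import pytest
import torch


@pytest.mark.parametrize(
    "precision, input_type, expected_type",
    [
        (32, torch.float16, torch.float32),  # full precision upcasts half inputs
        (16, torch.float32, torch.float16),  # 16-bit precision downcasts float inputs
    ],
)
def test_lite_module_forward_conversion(precision, input_type, expected_type):
    ...  # body as above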
Code example #5
from unittest.mock import Mock

import torch
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite.wrappers import _LiteModule


# Parametrized over (device, dtype); the decorator is omitted in this excerpt.
def test_lite_module_device_dtype_propagation(device, dtype):
    """Test that the LiteModule propagates device and dtype properties to its submodules (e.g. torchmetrics)."""

    class DeviceModule(DeviceDtypeModuleMixin):
        pass

    device_module = DeviceModule()
    lite_module = _LiteModule(device_module, Mock())
    # Moving the wrapper must move the wrapped module and keep both in sync.
    lite_module.to(device)
    assert device_module.device == device
    assert lite_module.device == device

    lite_module.to(dtype)
    assert device_module.dtype == dtype
    assert lite_module.dtype == dtype
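DeviceDtypeModuleMixin is what gives a plain nn.Module a tracked .device/.dtype; the test relies on the wrapper forwarding .to() calls so this bookkeeping stays consistent on both the wrapper and the wrapped module. A small standalone illustration of the mixin's observable behavior (TrackedModule is a made-up name):

import torch
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin


class TrackedModule(DeviceDtypeModuleMixin):
    """A bare module that only exposes the mixin's tracked .device/.dtype."""


module = TrackedModule()
module.to(torch.float64)
print(module.dtype)   # torch.float64: the dtype passed to .to() is recorded
module.to(torch.device("cpu"))
print(module.device)  # device(type='cpu'): the target device is recorded too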