Ejemplo n.º 1
0
 def test_not_torch(self, model):
     """A non-torch model is rejected by ``must_torch`` with a TypeError."""
     # Dots are escaped so they match literally. ``pytest.raises(match=...)``
     # applies ``re.search``, so no trailing ``.*`` is needed.
     msg_match = r"^model must be a torch\.nn\.Module"
     with pytest.raises(TypeError, match=msg_match):
         model_validator.must_torch(model)
Ejemplo n.º 2
0
    def create_standard_model_from_torch(
        self,
        obj,
        environment,
        model_api=None,
        name=None,
        desc=None,
        labels=None,
        attrs=None,
        lock_level=None,
    ):
        """Create a Standard Verta Model version from a PyTorch model.

        .. versionadded:: 0.18.2

        `obj` is validated as a PyTorch module first. The model and every
        other argument are then passed unchanged to the spec-based
        model-version creation routine.

        Parameters
        ----------
        obj : `torch.nn.Module <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`__
            PyTorch model.
        environment : :class:`~verta.environment.Python`
            pip and apt dependencies.
        model_api : :class:`~verta.utils.ModelAPI`, optional
            Model API specifying the model's expected input and output.
        name : str, optional
            Name of the model version. If no name is provided, one will be
            generated.
        desc : str, optional
            Description of the model version.
        labels : list of str, optional
            Labels of the model version.
        attrs : dict of str to {None, bool, float, int, str}, optional
            Attributes of the model version.
        lock_level : :mod:`~verta.registry.lock`, default :class:`~verta.registry.lock.Open`
            Lock level to set when creating this model version.

        Returns
        -------
        :class:`~verta.registry.entities.RegisteredModelVersion`

        Raises
        ------
        TypeError
            If `obj` is not a ``torch.nn.Module``.

        Examples
        --------
        .. code-block:: python

            import torch
            from verta.environment import Python

            class Model(torch.nn.Module):
                def __init__(self):
                    super(Model, self).__init__()
                    self.layer1 = torch.nn.Linear(3, 2)
                    self.layer2 = torch.nn.Linear(2, 1)

                def forward(self, x):
                    x = torch.nn.functional.relu(self.layer1(x))
                    return torch.sigmoid(self.layer2(x))

            model = Model()
            train(model, data)

            model_ver = reg_model.create_standard_model_from_torch(
                model,
                Python(["torch"]),
            )
            endpoint.update(model_ver, wait=True)
            endpoint.get_deployed_model().predict(input)

        """
        # Reject non-torch objects before any model version is created.
        model_validator.must_torch(obj)

        forwarded = {
            "model": obj,
            "environment": environment,
            "model_api": model_api,
            "name": name,
            "desc": desc,
            "labels": labels,
            "attrs": attrs,
            "lock_level": lock_level,
        }
        return self._create_standard_model_from_spec(**forwarded)
Ejemplo n.º 3
0
 def test_torch(self, model):
     """A genuine torch model passes ``must_torch`` with a truthy result."""
     is_valid = model_validator.must_torch(model)
     assert is_valid