Example #1
0
    def __init__(self,
                 shapes,
                 nb_pts,
                 gd,
                 tan,
                 cotan,
                 device=None,
                 dtype=None):
        """Initialise the manifold tensor storage.

        Parameters
        ----------
        shapes : Iterable
            Shapes of the underlying tensors (consumed through
            ``self.shape_gd`` by the containers built below).
        nb_pts : int
            Number of points.  # NOTE(review): inferred from the name — confirm
        gd, tan, cotan : Iterable of torch.Tensor or None
            Optional initial geometrical descriptor, tangent and cotangent
            tensors. Each may independently be None.
        device : torch.device, default=None
            Device the tensors live on. If None, it is inferred from the
            filled tensors, or defaults to CPU when nothing is filled.
        dtype : torch.dtype, default=None
            Dtype of the tensors. If None, it is inferred from the filled
            tensors, or defaults to the Torch default dtype.

        Raises
        ------
        RuntimeError
            If no explicit device/dtype is given and the filled tensors
            live on different devices or have different dtypes.
        """
        self.__shapes = shapes
        self.__nb_pts = nb_pts

        self.__initialised = True

        # No tensors are filled
        if (gd is None) and (tan is None) and (cotan is None):
            # Default device set to cpu if device is not specified.
            if device is None:
                device = torch.device('cpu')

            # dtype set to Torch default if dtype is not specified.
            # (Fixed: compare against None with `is`, not `==`.)
            if dtype is None:
                dtype = torch.get_default_dtype()

            self.__initialised = False
        # Some tensors (or all) are filled
        else:
            # No device is specified, we infer it from the filled tensors.
            if device is None:
                device = tensors_device(flatten_tensor_list([gd, tan, cotan]))

                # Found device is None, meaning the filled tensors lives on different devices.
                if device is None:
                    raise RuntimeError(
                        "BaseManifold.__init__(): at least two initial manifold tensors live on different devices!"
                    )

            # No dtype is specified, we infer it from the filled tensors.
            if dtype is None:
                dtype = tensors_dtype(flatten_tensor_list([gd, tan, cotan]))

                # Found dtype is None, meaning the filled tensors are of different dtypes.
                if dtype is None:
                    raise RuntimeError(
                        "BaseManifold.__init__(): at least two initial manifold tensors are of different dtype!"
                    )

        # One container per role; all share the same device/dtype.
        self.__gd = TensorContainer(self.shape_gd, device, dtype)
        self.__tan = TensorContainer(self.shape_gd, device, dtype)
        self.__cotan = TensorContainer(self.shape_gd, device, dtype)

        # Fill only the containers for which initial tensors were provided
        # (by reference: no cloning, requires_grad left untouched).
        if gd is not None:
            self.__gd.fill(gd, False, False)

        if tan is not None:
            self.__tan.fill(tan, False, False)

        if cotan is not None:
            self.__cotan.fill(cotan, False, False)

        super().__init__(device, dtype)
Example #2
0
    def __init__(self, fields):
        """Store *fields* and initialise the base class with their shared
        dimension and device.
        """
        common_dim = shared_tensors_property(fields, lambda field: field.dim)
        # The device check was deliberately relaxed: tensors_device may
        # return None here without failing.
        common_device = tensors_device(fields)
        assert common_dim is not None
        super().__init__(common_dim, common_device)

        self.__fields = fields
Example #3
0
    def fill(self, tensors, clone=False, requires_grad=None):
        """Fill the tensor container with **tensors**.

        Parameters
        ----------
        tensors : Iterable
            Iterable of torch.Tensor for multidimensional tensor or torch.Tensor for simple tensor we want to fill with.
        clone : bool, default=False
            Set to true to clone the tensors. This will detach the computation graph. If false, tensors will be passed as references.
        requires_grad : bool, default=None
            Set to true to record further operations on the tensors. Only relevant when cloning the tensors.

        Raises
        ------
        RuntimeError
            If the input tensors live on different devices or are of
            different dtypes.
        """
        # When clone=False and requires_grad=None, the tensors are assigned
        # by reference without touching their requires_grad flag.

        device = tensors_device(tensors)
        if device is None:
            raise RuntimeError(
                "BaseManifold.__ManifoldTensorContainer.fill(): at least two input tensors lives on different devices!"
            )

        self.__device = device

        dtype = tensors_dtype(tensors)
        if dtype is None:
            raise RuntimeError(
                "BaseManifold.__ManifoldTensorContainer.fill(): at least two input tensors are of different dtypes!"
            )

        self.__dtype = dtype

        # Normalise a single bare tensor into a one-element tuple so the
        # branches below can treat both cases uniformly.
        if len(self.__shapes) == 1 and isinstance(tensors, torch.Tensor):
            tensors = (tensors, )

        if clone and requires_grad is not None:
            # Clone, detach from the autograd graph and set the gradient
            # flag explicitly.
            self.__tensors = tuple(
                tensor.detach().clone().requires_grad_(requires_grad)
                for tensor in tensors)
        elif clone:
            # requires_grad is None here: clone() alone, which keeps the
            # autograd graph and the original flags.
            self.__tensors = tuple(tensor.clone() for tensor in tensors)
        else:
            # Plain reference assignment; a non-None requires_grad is
            # intentionally ignored when not cloning.
            self.__tensors = tuple(tensors)
Example #4
0
 def __init__(self, manifolds):
     """Wrap *manifolds* and forward their shared device and dtype to the
     base class.
     """
     self.__manifolds = manifolds
     shared_device = tensors_device(self.__manifolds)
     shared_dtype = tensors_dtype(self.__manifolds)
     super().__init__(shared_device, shared_dtype)