def __init__(self, shapes, nb_pts, gd, tan, cotan, device=None, dtype=None):
    """Initialise the manifold and its ``gd`` / ``tan`` / ``cotan`` containers.

    Parameters
    ----------
    shapes :
        Shapes of the manifold tensors, stored as-is (presumably consumed by
        ``shape_gd``, which is defined elsewhere -- TODO confirm).
    nb_pts :
        Number of points, stored as-is.
    gd, tan, cotan : torch.Tensor or Iterable of torch.Tensor or None
        Initial content of the three containers. A ``None`` entry leaves the
        corresponding container unfilled.
    device : torch.device, default=None
        Device of the containers. When None, it is inferred from the given
        tensors, or set to CPU if none of ``gd``/``tan``/``cotan`` is given.
    dtype : torch.dtype, default=None
        Dtype of the containers. When None, it is inferred from the given
        tensors, or set to ``torch.get_default_dtype()`` if none is given.

    Raises
    ------
    RuntimeError
        If ``device`` (resp. ``dtype``) has to be inferred and the given
        tensors do not share a single device (resp. dtype).
    """
    self.__shapes = shapes
    self.__nb_pts = nb_pts
    # Cleared below when the manifold is created without any tensor.
    self.__initialised = True

    if gd is None and tan is None and cotan is None:
        # Nothing to infer from: fall back to CPU and torch's default dtype.
        if device is None:
            device = torch.device('cpu')
        if dtype is None:
            dtype = torch.get_default_dtype()
        self.__initialised = False
    else:
        # Infer only what the caller did not specify.
        if device is None:
            device = tensors_device(flatten_tensor_list([gd, tan, cotan]))
            # tensors_device returns None when the tensors disagree on device.
            if device is None:
                raise RuntimeError(
                    "BaseManifold.__init__(): at least two initial manifold tensors live on different devices!"
                )

        if dtype is None:
            dtype = tensors_dtype(flatten_tensor_list([gd, tan, cotan]))
            # tensors_dtype returns None when the tensors disagree on dtype.
            if dtype is None:
                raise RuntimeError(
                    "BaseManifold.__init__(): at least two initial manifold tensors are of different dtype!"
                )

    # tan and cotan are allocated with the same shape as gd.
    self.__gd = TensorContainer(self.shape_gd, device, dtype)
    self.__tan = TensorContainer(self.shape_gd, device, dtype)
    self.__cotan = TensorContainer(self.shape_gd, device, dtype)

    # Store the given tensors by reference (no clone, no requires_grad change).
    if gd is not None:
        self.__gd.fill(gd, False, False)
    if tan is not None:
        self.__tan.fill(tan, False, False)
    if cotan is not None:
        self.__cotan.fill(cotan, False, False)

    super().__init__(device, dtype)
def __init__(self, fields):
    """Build a compound object out of ``fields``.

    Parameters
    ----------
    fields :
        Collection of field objects. Their common ``dim`` attribute and
        their device are forwarded to the parent initialiser.

    Raises
    ------
    AssertionError
        If ``shared_tensors_property`` reports no common ``dim``
        (presumably because the fields disagree -- TODO confirm).
    """
    common_dim = shared_tensors_property(fields, lambda field: field.dim)
    # No device consistency check is performed: whatever tensors_device
    # returns is handed to the parent unchanged.
    fields_device = tensors_device(fields)
    assert common_dim is not None

    super().__init__(common_dim, fields_device)
    self.__fields = fields
def fill(self, tensors, clone=False, requires_grad=None):
    """Fill the tensor container with **tensors**.

    Parameters
    ----------
    tensors : Iterable or torch.Tensor
        Iterable of torch.Tensor for a multidimensional container, or a
        single torch.Tensor when the container holds exactly one tensor.
    clone : bool, default=False
        Set to true to store clones of the tensors instead of references.
        When ``requires_grad`` is also given, each tensor is detached from
        the computation graph before being cloned. When ``requires_grad``
        is None, ``Tensor.clone()`` is used as-is, which is differentiable
        and therefore stays connected to the graph.
    requires_grad : bool, default=None
        ``requires_grad`` flag to set on the cloned tensors. Only relevant
        when ``clone`` is true; None leaves the flag as produced by clone.

    Raises
    ------
    RuntimeError
        If the input tensors live on different devices or have different
        dtypes. The container is left unmodified in that case.
    """
    device = tensors_device(tensors)
    if device is None:
        raise RuntimeError(
            "BaseManifold.__ManifoldTensorContainer.fill(): at least two input tensors lives on different devices!"
        )

    dtype = tensors_dtype(tensors)
    if dtype is None:
        raise RuntimeError(
            "BaseManifold.__ManifoldTensorContainer.fill(): at least two input tensors are of different dtypes!"
        )

    # Commit device and dtype only once both checks have passed, so a failed
    # fill does not leave the container in a half-updated state.
    self.__device = device
    self.__dtype = dtype

    # A bare tensor is accepted for single-tensor containers.
    if len(self.__shapes) == 1 and isinstance(tensors, torch.Tensor):
        tensors = (tensors, )

    if not clone:
        # Store references; requires_grad is intentionally ignored here.
        self.__tensors = tuple(tensors)
    elif requires_grad is None:
        self.__tensors = tuple(tensor.clone() for tensor in tensors)
    else:
        self.__tensors = tuple(
            tensor.detach().clone().requires_grad_(requires_grad)
            for tensor in tensors)
def __init__(self, manifolds):
    """Aggregate ``manifolds`` into a single compound object.

    Parameters
    ----------
    manifolds :
        Collection of sub-manifolds. It is stored as-is, and its device and
        dtype (as computed by ``tensors_device`` / ``tensors_dtype``) are
        forwarded to the parent initialiser without further checks.
    """
    self.__manifolds = manifolds
    shared_device, shared_dtype = tensors_device(manifolds), tensors_dtype(manifolds)
    super().__init__(shared_device, shared_dtype)