Beispiel #1
0
class ImageConstant_Torch(input.ImageConstant):
    """Input node holding a single image loaded from disk as a tensor.

    The image is passed through ``transforms.Resize(imsize)`` and
    ``transforms.ToTensor()``, given an extra leading dimension of size 1
    (``unsqueeze(0)``), cast to ``torch.float`` and moved to ``device``.
    The result is wrapped in a ``Tensor_Torch`` named
    ``<name>_image_tensor``; all other methods delegate to that wrapper.
    """

    def __init__(self,
                 image_path: str,
                 imsize: int = 512,
                 name: str = "ImageConstant_Torch",
                 device=torch.device("cpu")):
        """Load the image at ``image_path`` and wrap it in a Tensor_Torch.

        Args:
            image_path: Path of the image file opened with PIL.
            imsize: Size argument forwarded to ``transforms.Resize``.
            name: Name of this input; also the prefix of the tensor name.
            device: Device the image tensor is moved to.
        """
        super(ImageConstant_Torch, self).__init__(name)
        loader = transforms.Compose(
            [transforms.Resize(imsize),
             transforms.ToTensor()])
        # PIL opens files lazily; the context manager guarantees the file
        # handle is released once the pixels have been converted, instead
        # of relying on PIL/GC to close it (multi-frame formats stay open).
        with Image.open(image_path, mode="r") as img:
            image = loader(img).unsqueeze(0)
        self.linked_tensor_torch = Tensor_Torch(image.to(device, torch.float),
                                                name=self.name +
                                                "_image_tensor")

    def get_saved_tensor(self):
        """Return the wrapped ``Tensor_Torch`` holding the image."""
        return self.linked_tensor_torch

    def set_device(self, device: torch.device):
        """Move the wrapped tensor to ``device`` via ``Tensor_Torch``."""
        self.linked_tensor_torch.set_device(device=device)

    def get_device(self):
        """Return the device reported by the wrapped tensor."""
        return self.linked_tensor_torch.get_device()

    # return KB in memory usage for the loaded tensor
    def get_tensor_memory_size(self):
        """Return the loaded tensor's memory usage (KB) from the wrapper."""
        return self.linked_tensor_torch.get_self_memory_size()

    # return KB in memory usage for gradients of the loaded tensor
    def get_tensor_grad_memory_size(self):
        """Return the gradient memory usage (KB) from the wrapper."""
        return self.linked_tensor_torch.get_grad_memory_size()

    def remove_from_tracking_gradient(self):
        """Delegate to ``Tensor_Torch.remove_from_tracking_gradient``."""
        return self.linked_tensor_torch.remove_from_tracking_gradient()

    def start_tracking_gradient(self):
        """Delegate to ``Tensor_Torch.start_tracking_gradient``."""
        return self.linked_tensor_torch.start_tracking_gradient()

    @staticmethod
    def get_description():
        """Return a short human-readable description of this input."""
        return "Loader for single image"
Beispiel #2
0
class ConstantConstant_Torch(input.ConstantConstant):
    """Input node holding a tensor of shape ``view`` filled with ``value``.

    The tensor uses the default floating dtype and is wrapped in a
    ``Tensor_Torch`` named ``<name>_const_tensor``; all other methods
    delegate to that wrapper.
    """

    def __init__(self,
                 view: list,
                 value: int,
                 name: str = "ConstantConstant_Torch",
                 device=torch.device("cpu")):
        """Create the constant tensor on ``device``.

        Args:
            view: Shape of the tensor (sequence of dimension sizes).
            value: Fill value for every element.
            name: Name of this input; also the prefix of the tensor name.
            device: Device the tensor is allocated on.
        """
        super(ConstantConstant_Torch, self).__init__(name)
        # One allocation instead of zeros + add. The dtype is pinned to the
        # default float dtype because torch.full would otherwise infer an
        # integer dtype from an int fill value, unlike zeros(...) + value.
        constant = torch.full(view, value,
                              dtype=torch.get_default_dtype(),
                              device=device)
        # "_" separator matches the sibling inputs ("_image_tensor",
        # "_saved_tensor"); previously the name was glued on without it.
        self.linked_tensor_torch = Tensor_Torch(
            constant,
            name=self.name + "_const_tensor")

    def get_saved_tensor(self):
        """Return the wrapped ``Tensor_Torch`` holding the constant."""
        return self.linked_tensor_torch

    def set_device(self, device: torch.device):
        """Move the wrapped tensor to ``device`` via ``Tensor_Torch``."""
        self.linked_tensor_torch.set_device(device=device)

    def get_device(self):
        """Return the device reported by the wrapped tensor."""
        return self.linked_tensor_torch.get_device()

    # return KB in memory usage for the loaded tensor
    def get_tensor_memory_size(self):
        """Return the tensor's memory usage (KB) from the wrapper."""
        return self.linked_tensor_torch.get_self_memory_size()

    # return KB in memory usage for gradients of the loaded tensor
    def get_tensor_grad_memory_size(self):
        """Return the gradient memory usage (KB) from the wrapper."""
        return self.linked_tensor_torch.get_grad_memory_size()

    def remove_from_tracking_gradient(self):
        """Delegate to ``Tensor_Torch.remove_from_tracking_gradient``."""
        return self.linked_tensor_torch.remove_from_tracking_gradient()

    def start_tracking_gradient(self):
        """Delegate to ``Tensor_Torch.start_tracking_gradient``."""
        return self.linked_tensor_torch.start_tracking_gradient()

    @staticmethod
    def get_description():
        """Return a short human-readable description of this input."""
        return "Constant tensor constant (1, 0)"
Beispiel #3
0
class TensorConstant_Torch(input.TensorConstant):
    """Input node holding a tensor deserialized with ``torch.load``.

    The loaded tensor is placed on ``device`` and wrapped in a
    ``Tensor_Torch`` named ``<name>_saved_tensor``; all other methods
    delegate to that wrapper.
    """

    def __init__(self,
                 tensor_path: str,
                 name: str = "TensorConstant_Torch",
                 device=torch.device("cpu")):
        """Load the tensor stored at ``tensor_path`` onto ``device``.

        Args:
            tensor_path: Path of a file written with ``torch.save``.
            name: Name of this input; also the prefix of the tensor name.
            device: Device the loaded tensor is placed on.
        """
        super(TensorConstant_Torch, self).__init__(name)
        # map_location restores storages directly onto ``device`` instead of
        # the device they were saved from, so a tensor saved on a GPU can be
        # loaded on a CPU-only host. The trailing .to() is then a no-op.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # paths (consider weights_only=True where the torch version allows).
        self.linked_tensor_torch = Tensor_Torch(
            torch.load(tensor_path, map_location=device).to(device),
            name=self.name + "_saved_tensor")

    def get_saved_tensor(self):
        """Return the wrapped ``Tensor_Torch`` holding the loaded tensor."""
        return self.linked_tensor_torch

    def set_device(self, device: torch.device):
        """Move the wrapped tensor to ``device`` via ``Tensor_Torch``."""
        self.linked_tensor_torch.set_device(device=device)

    def get_device(self):
        """Return the device reported by the wrapped tensor."""
        return self.linked_tensor_torch.get_device()

    # return KB in memory usage for the loaded tensor
    def get_tensor_memory_size(self):
        """Return the loaded tensor's memory usage (KB) from the wrapper."""
        return self.linked_tensor_torch.get_self_memory_size()

    # return KB in memory usage for gradients of the loaded tensor
    def get_tensor_grad_memory_size(self):
        """Return the gradient memory usage (KB) from the wrapper."""
        return self.linked_tensor_torch.get_grad_memory_size()

    def remove_from_tracking_gradient(self):
        """Delegate to ``Tensor_Torch.remove_from_tracking_gradient``."""
        return self.linked_tensor_torch.remove_from_tracking_gradient()

    def start_tracking_gradient(self):
        """Delegate to ``Tensor_Torch.start_tracking_gradient``."""
        return self.linked_tensor_torch.start_tracking_gradient()

    @staticmethod
    def get_description():
        """Return a short human-readable description of this input."""
        return "Constant Tensor constant"