예제 #1
0
파일: div2k.py 프로젝트: zack466/autoreg-sr
    def cache_func(self, i):
        """Build and cache the ``i``-th chunk of aligned LR/HR training crops.

        Custom loader used with ``HDF5Cache``. It reads
        ``self.cache.cache_size // self.mult`` DIV2K image pairs, starting at
        offset ``i * self.cache.cache_size // self.mult``. It takes
        ``self.mult`` random, spatially aligned crops from each pair and
        stores the stacked float32 arrays via
        ``self.cache.cache_images(i, lr, hr)``.

        Args:
            i: Index of the chunk to build.
        """
        num_train_images = 800  # image numbers above this are wrapped back
        images_per_chunk = self.cache.cache_size // self.mult
        offset = i * self.cache.cache_size // self.mult
        f = self.factor
        hr_size = self.size * f  # HR crop side; the LR crop side is hr_size // f
        to_tensor = ToTensor()

        lr_images = []
        hr_images = []
        for idx in range(images_per_chunk):
            img_num = offset + idx + 1
            if img_num > num_train_images:
                # NOTE(review): steps back by cache_size (a crop count), not by
                # images_per_chunk; kept as-is — confirm the intended wrap.
                img_num -= self.cache.cache_size
            stem = str(img_num).zfill(4)
            img_hr_name = "./datasets/saved/DIV2K_train_HR/" + stem + ".png"
            img_lr_name = (
                f"./datasets/saved/DIV2K_train_LR_bicubic/X{self.factor}/" +
                stem + f"x{self.factor}.png")
            img_hr = Image.open(img_hr_name)
            img_lr = Image.open(img_lr_name)

            for _ in range(self.mult):
                # Sample the window on the HR image, then take the co-located
                # window (all coordinates divided by f) on the LR image.
                top, left, height, width = RandomCrop.get_params(
                    img_hr, (hr_size, hr_size))
                hr_crop = TF.crop(img_hr, top, left, height, width)
                lr_crop = TF.crop(img_lr, top // f, left // f,
                                  height // f, width // f)
                lr_images.append(to_tensor(lr_crop))
                hr_images.append(to_tensor(hr_crop))

        # Each crop tensor from ToTensor is C,H,W, so the stacks are N,C,H,W.
        lr_stacked = np.stack(lr_images).astype(np.float32)
        hr_stacked = np.stack(hr_images).astype(np.float32)
        self.cache.cache_images(i, lr_stacked, hr_stacked)
예제 #2
0
def train_hr_transform(crop_size):
    """Return the HR training pipeline: a random ``crop_size`` crop, then ToTensor."""
    steps = [RandomCrop(crop_size), ToTensor()]
    return Compose(steps)
예제 #3
0
def transform_target(crop_size):
    """Ground-truth image pipeline: random ``crop_size`` crop, then a random horizontal flip."""
    crop = RandomCrop(crop_size)
    flip = RandomHorizontalFlip()
    return Compose([crop, flip])
예제 #4
0
 def img_transform(self, crop_size):
     """Resize to [512, 384], random-crop ``crop_size``, convert to a tensor and normalize.

     The mean/std values are the commonly used ImageNet channel statistics.
     """
     channel_mean = [0.485, 0.456, 0.406]
     channel_std = [0.229, 0.224, 0.225]
     normalize = Normalize(mean=channel_mean, std=channel_std)
     pipeline = [
         Resize([512, 384]),
         RandomCrop(crop_size),
         ToTensor(),
         normalize,
     ]
     return Compose(pipeline)
예제 #5
0
  def __init__(self, opts):
    """Index per-domain image files and build the image transform.

    Reads ``dataroot``, ``num_domains``, ``input_dim``, ``input_nz``,
    ``phase``, ``resize_size``, ``img_size`` and ``no_flip`` from ``opts``.
    Domain ``i`` is listed from ``<dataroot>/<phase><letter>``, where the
    letter runs 'A', 'B', ... in order.
    """
    self.dataroot = opts.dataroot
    self.num_domains = opts.num_domains
    self.input_dim = opts.input_dim
    self.nz = opts.input_nz

    domains = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
    self.images = [None] * self.num_domains
    for i in range(self.num_domains):
      img_dir = os.path.join(self.dataroot, opts.phase + domains[i])
      self.images[i] = [os.path.join(img_dir, x) for x in os.listdir(img_dir)]
    # One "epoch" spans the largest domain.
    self.dataset_size = max(len(imgs) for imgs in self.images)

    # setup image transformation
    transforms = [Resize((opts.resize_size, opts.resize_size), Image.BICUBIC)]
    if opts.phase == 'train':
      transforms.append(RandomCrop(opts.img_size))
    else:
      transforms.append(CenterCrop(opts.img_size))
    if not opts.no_flip:
      transforms.append(RandomHorizontalFlip())
    transforms.append(ToTensor())
    transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    self.transforms = Compose(transforms)
예제 #6
0
    def transform_train(self, data, targets=None, batch_size=None):
        """Wrap the training data in a shuffling DataLoader.

        When ``self.augment`` is set, each sample goes through ToPILImage, a
        4-pixel padded RandomCrop of size ``data.shape[1:3]``, a random
        horizontal flip, ToTensor, normalization and Cutout. Otherwise only
        normalization is applied.

        Args:
            data: Training samples. ``data.shape[1]`` and ``data.shape[2]``
                are read as spatial sizes (presumably N,H,W,... — confirm).
            targets: Training targets, forwarded to ``self._transform``.
            batch_size: Defaults to ``Constant.MAX_BATCH_SIZE``; capped at
                ``len(data)``.

        Returns:
            A shuffling ``DataLoader`` over the transformed dataset.
        """
        short_edge = min(data.shape[1], data.shape[2])
        normalize = Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))
        if self.augment:
            compose_list = [
                ToPILImage(),
                RandomCrop(data.shape[1:3], padding=4),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
                Cutout(n_holes=Constant.CUTOUT_HOLES,
                       length=int(short_edge * Constant.CUTOUT_RATIO)),
            ]
        else:
            compose_list = [normalize]

        dataset = self._transform(compose_list, data, targets)

        requested = Constant.MAX_BATCH_SIZE if batch_size is None else batch_size
        return DataLoader(dataset, batch_size=min(len(data), requested), shuffle=True)
예제 #7
0
    def __init__(self, opt, val=False):
        """CIFAR-100 wrapper; ``val`` selects the test split, otherwise the train split.

        The split is downloaded into ``opt.dir_dataset`` when missing. Both
        splits end with ToTensor and per-channel normalization. The train
        split first applies a 32x32 random crop with 4-pixel zero padding and
        a random horizontal flip.
        """
        super(CustomCIFAR100, self).__init__()
        dir_dataset = opt.dir_dataset

        self.dataset = CIFAR100(root=dir_dataset,
                                train=not val,
                                download=True)

        to_normalized_tensor = [
            ToTensor(),
            Normalize(mean=[0.507, 0.487, 0.441],
                      std=[0.267, 0.256, 0.276]),
        ]
        if val:
            self.transform = Compose(to_normalized_tensor)
        else:
            augmentation = [
                RandomCrop((32, 32),
                           padding=4,
                           fill=0,
                           padding_mode='constant'),
                RandomHorizontalFlip(),
            ]
            self.transform = Compose(augmentation + to_normalized_tensor)
예제 #8
0
    def __getitem__(self, index):
        """Return ``(depth, target)`` for sample ``index``.

        The file is a pre-composed 4-channel image: channels 0-2 are RGB and
        channel 3 is depth. With ``self.crop`` set, a random 64-pixel patch is
        taken and randomly flipped and rotated. ``depth`` is the network
        input: the depth map, downsampled by ``self.upscale_factor`` and
        bicubically resized back, concatenated with the grayscale RGB guide
        along the channel axis. ``target`` is the unmodified depth map.
        """
        # Pre-composed 4-channel RGB-D image.
        rgbd = load_img(self.image_filenames[index])
        # Data augmentation
        if self.crop:
            rgbd = RandomCrop(64)(rgbd)  # take a patch
            rgbd = RandomHorizontalFlip()(rgbd)  # horizontal flip
            rgbd = RandomVerticalFlip()(rgbd)  # vertical flip
            rgbd = RandomRotation(180)(rgbd)  # random rotation
        rgbd_tensor = ToTensor()(rgbd)
        # Slicing keeps the channel axis: (3, H, W) and (1, H, W).
        rgb_tensor = rgbd_tensor[0:3]
        depth_tensor = rgbd_tensor[3:4]

        depth = ToPILImage()(depth_tensor)
        size = min(depth.size[0], depth.size[1])
        guide = ToPILImage()(rgb_tensor).convert('L')
        target = depth.copy()

        # Generate the LR input: downsample, then bicubic-resize so the
        # shorter side is `size` again.
        depth = downsampling(depth, self.upscale_factor)
        depth = Resize(size=size, interpolation=Image.BICUBIC)(depth)

        depth = ToTensor()(depth)
        guide = ToTensor()(guide)
        depth = torch.cat((depth, guide), 0)  # concatenate into the input tensor
        target = ToTensor()(target)

        return depth, target
예제 #9
0
    def transform_d(self, image, mask):
        """Jointly augment an image/mask pair.

        The same random ``crop_size`` x ``crop_size`` crop is applied to both.
        Only ``image`` is then resized to ``crop_size // upscale_factor`` on
        each side. Next, the same random horizontal and vertical flips (each
        taken when ``random.random() > 0.5``) are applied to both, and both
        are converted to tensors.

        Returns:
            ``(image, mask)`` tensors; ``mask`` keeps the full crop size.
        """
        # Random crop (same window for image and mask)
        i, j, h, w = RandomCrop.get_params(image,
                                           output_size=(self.crop_size,
                                                        self.crop_size))
        image = TF.crop(image, i, j, h, w)
        mask = TF.crop(mask, i, j, h, w)

        # Downscale only the image (presumably the low-resolution input).
        lr_side = self.crop_size // self.upscale_factor
        image = TF.resize(image, (lr_side, lr_side))

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random vertical flipping
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        # Transform to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask
예제 #10
0
def get_transforms(augment):
    """Return ``{'training': ..., 'validation': ...}`` ImageNet transforms.

    Validation always uses Resize(256) -> CenterCrop(224) -> ToTensor ->
    Normalize. ``augment`` selects the training pipeline:

    * ``False`` -- same as validation.
    * ``True`` -- RandomResizedCrop(224), flip, ColorJitter and
      ``Lighting(_ImageNet['PCA'])``.
    * ``"torchvision"`` -- RandomResizedCrop(224) and flip.
    * ``"torchvision2"`` -- Resize(256), RandomCrop(224) and flip.

    Raises:
        ValueError: if ``augment`` is none of the above.
    """
    valid_t = Compose([Resize(256),
                       CenterCrop(224),
                       ToTensor(),
                       Normalize(**_ImageNet['Normalize'])])
    # `==` (not `is`) is kept on purpose so 0/1 still select the
    # False/True pipelines.
    if augment == False:  # noqa: E712
        train_t = valid_t
    elif augment == True:  # noqa: E712
        train_t = Compose([RandomResizedCrop(224),
                           RandomHorizontalFlip(),
                           ToTensor(),
                           ColorJitter(),
                           Lighting(_ImageNet['PCA']),
                           Normalize(**_ImageNet['Normalize'])])
    elif augment == "torchvision":
        train_t = Compose([RandomResizedCrop(224),
                           RandomHorizontalFlip(),
                           ToTensor(),
                           Normalize(**_ImageNet['Normalize'])])
    elif augment == "torchvision2":
        train_t = Compose([Resize(256),
                           RandomCrop(224),
                           RandomHorizontalFlip(),
                           ToTensor(),
                           Normalize(**_ImageNet['Normalize'])])
    else:
        # Explicit raise: an `assert False` is stripped under `python -O`,
        # which would leave `train_t` unbound below.
        raise ValueError(
            "augment must be False, True, 'torchvision' or 'torchvision2', "
            "got {!r}".format(augment))

    transforms = {
        'training':   train_t,
        'validation': valid_t
    }
    return transforms
예제 #11
0
    def transform_train(self, data, targets=None, batch_size=None):
        """Transform the training data, perform random cropping data augmentation and basic random flip augmentation.

        Only 4-D ``data`` is transformed: it is normalized and, when
        ``self.augment`` is set, also passed through ToPILImage -> padded
        RandomCrop -> RandomHorizontalFlip -> ToTensor -> Normalize -> Cutout.
        Data of any other rank is handed to ``self._transform`` with no
        transforms.

        Args:
            data: Training samples; for 4-D data ``data.shape[1:3]`` is the crop size.
            targets: the target of training set.
            batch_size: int batch_size. Defaults to ``Constant.MAX_BATCH_SIZE``
                and is capped at ``len(data)``.

        Returns:
            A shuffling DataLoader class instance.
        """
        if len(data.shape) != 4:
            # Check the rank before touching data.shape[2], which does not
            # exist for 1-D/2-D data.
            compose_list = []
        else:
            short_edge_length = min(data.shape[1], data.shape[2])
            common_list = [Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))]
            if self.augment:
                compose_list = [ToPILImage(),
                                RandomCrop(data.shape[1:3], padding=4),
                                RandomHorizontalFlip(),
                                ToTensor()
                                ] + common_list + [Cutout(n_holes=Constant.CUTOUT_HOLES,
                                                          length=int(short_edge_length * Constant.CUTOUT_RATIO))]
            else:
                compose_list = common_list

        dataset = self._transform(compose_list, data, targets)

        if batch_size is None:
            batch_size = Constant.MAX_BATCH_SIZE
        batch_size = min(len(data), batch_size)

        return DataLoader(dataset, batch_size=batch_size, shuffle=True)
예제 #12
0
def image_crop_rescale(sample, crop_size, color_mode):
    """Randomly crop ``sample`` and resize the crop back to the sample's size.

    ``sample`` is converted to a PIL image in ``color_mode`` and cropped to
    ``crop_size``. The crop is resized to ``(sample.shape[1], sample.shape[2])``
    with ``interpolation=0`` (PIL's NEAREST constant) and converted back to a
    tensor.
    """
    pil_image = ToPILImage(mode=color_mode)(sample)
    patch = RandomCrop(crop_size)(pil_image)
    target_size = (sample.shape[1], sample.shape[2])
    patch = Resize(target_size, interpolation=0)(patch)
    return ToTensor()(patch)
예제 #13
0
def main(args):
    """Run a Naive strategy with replay on MNIST split via ``nc_scenario``.

    MNIST is split into ``n_batches`` (5) experiences. A ``SimpleMLP`` is
    trained on each experience with Adam and a ``ReplayPlugin`` memory of
    10000 samples, then evaluated on the whole test stream.

    Args:
        args: Parsed CLI arguments; ``args.cuda`` is the CUDA device index
            (a negative value, or no CUDA available, selects the CPU).

    Returns:
        list: One ``cl_strategy.eval(scenario.test_stream)`` result per
        training experience, in order.
    """
    # --- CONFIG
    device = torch.device(f"cuda:{args.cuda}"
                          if torch.cuda.is_available() and
                          args.cuda >= 0 else "cpu")
    n_batches = 5
    # ---------

    # --- TRANSFORMATIONS
    train_transform = transforms.Compose([
        RandomCrop(28, padding=4),
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_transform = transforms.Compose([
        ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # ---------

    # --- SCENARIO CREATION
    mnist_train = MNIST('./data/mnist', train=True,
                        download=True, transform=train_transform)
    mnist_test = MNIST('./data/mnist', train=False,
                       download=True, transform=test_transform)
    scenario = nc_scenario(
        mnist_train, mnist_test, n_batches, task_labels=False, seed=1234)
    # ---------

    # MODEL CREATION
    model = SimpleMLP(num_classes=scenario.n_classes)

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(
            minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        ExperienceForgetting(),
        loggers=[interactive_logger])

    # CREATE THE STRATEGY INSTANCE (NAIVE)
    cl_strategy = Naive(model, torch.optim.Adam(model.parameters(), lr=0.001),
                        CrossEntropyLoss(),
                        train_mb_size=100, train_epochs=4, eval_mb_size=100, device=device,
                        plugins=[ReplayPlugin(mem_size=10000)],
                        evaluator=eval_plugin
                        )

    # TRAINING LOOP
    print('Starting experiment...')
    results = []
    for experience in scenario.train_stream:
        print("Start of experience ", experience.current_experience)
        cl_strategy.train(experience)
        print('Training completed')

        print('Computing accuracy on the whole test set')
        results.append(cl_strategy.eval(scenario.test_stream))

    # Hand the collected evaluations back instead of discarding them.
    return results
예제 #14
0
	def __init__(self, seq_name, vis_threshold, P, K, max_per_person, crop_H, crop_W,
				transform, normalize_mean=None, normalize_std=None):
		"""CUHK03 dataset backed by ``cfg.DATA_DIR/cuhk03_release/cuhk-03.mat``.

		Raises RuntimeError when the .mat file is missing. ``transform`` must
		be "random" (random crop + horizontal flip) or "center" (center crop);
		anything else raises NotImplementedError. A non-empty ``seq_name``
		must be 'labeled' or 'detected'; an empty one yields no images.
		``vis_threshold`` is not used in this method.
		"""
		self.data_dir = osp.join(cfg.DATA_DIR, 'cuhk03_release')
		self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')
		if not osp.exists(self.raw_mat_path):
			raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
		self.seq_name = seq_name

		self.P, self.K = P, K
		self.max_per_person = max_per_person
		self.crop_H, self.crop_W = crop_H, crop_W

		crop_size = (crop_H, crop_W)
		if transform == "random":
			steps = [RandomCrop(crop_size), RandomHorizontalFlip()]
		elif transform == "center":
			steps = [CenterCrop(crop_size)]
		else:
			raise NotImplementedError("Tranformation not understood: {}".format(transform))
		steps += [ToTensor(), Normalize(normalize_mean, normalize_std)]
		self.transform = Compose(steps)

		if not seq_name:
			self.data = []
		else:
			assert seq_name in ['labeled', 'detected']
			self.data = self.load_images()

		self.build_samples()
    def __getitem__(self, index):
        """Return ``(input_tensor, target_tensor)`` for the image at ``index``.

        The image is random-cropped to ``crop_size`` (when > 0) and, when
        ``self.flip`` is set, horizontally flipped with probability 0.5.
        ``target_tensor`` is the image normalized with mean/std 0.5 per
        channel (assumed 3-channel, matching the 3-value Normalize).
        ``input_tensor`` is its single-channel grayscale version, normalized
        the same way.
        """
        image = Image.open(os.path.join(
            self.root,
            self.list_paths[index]))  # Open image from the given path.

        # Common transforms applied before splitting into input/target.
        list_transforms = []
        if self.crop_size > 0:
            list_transforms.append(RandomCrop(
                (self.crop_size, self.crop_size)))
        if self.flip:
            # RandomHorizontalFlip already flips with p=0.5; gating it behind
            # an extra coin toss would cut the flip rate to 0.25.
            list_transforms.append(RandomHorizontalFlip())
        image = Compose(list_transforms)(image)

        # The input is the black-and-white version of the (transformed) image.
        input_image = Grayscale(num_output_channels=1)(image)

        input_tensor = ToTensor()(input_image)
        target_tensor = ToTensor()(image)

        # One channel in, so a single mean/std value.
        input_tensor = Normalize(mean=[0.5], std=[0.5])(input_tensor)
        # Three channels in the target, so three mean/std values.
        target_tensor = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(
            target_tensor)

        return input_tensor, target_tensor
예제 #16
0
    def __init__(self,
                 root,
                 list_file,
                 patch_size=96,
                 shrink_size=2,
                 noise_level=1,
                 down_sample_method=None,
                 transform=None):
        """Index the image entries listed in ``list_file``.

        Args:
            root: Dataset root directory.
            list_file: Path to a text file with one entry per line, or a list
                of such paths whose contents are concatenated in order.
            patch_size: Side of the random crop (``RandomCrop(size=...)``).
            shrink_size, noise_level, down_sample_method: Forwarded to
                ``ImageAugment``.
            transform: Optional transform stored on the instance.

        Lines are stored verbatim in ``self.fnames`` (trailing newlines
        included) and counted in ``self.num_imgs``.
        """
        import io

        self.root = root
        self.transform = transform
        self.random_cropper = RandomCrop(size=patch_size)
        self.img_augmenter = ImageAugment(shrink_size, noise_level,
                                          down_sample_method)

        if isinstance(list_file, list):
            # Concatenate in-process (same result as `cat a b ...`) rather
            # than running a shell command built from unquoted paths and
            # writing to a fixed, shared /tmp file.
            parts = []
            for path in list_file:
                with open(path) as f:
                    parts.append(f.read())
            lines = io.StringIO(''.join(parts)).readlines()
        else:
            with open(list_file) as f:
                lines = f.readlines()

        self.num_imgs = len(lines)
        self.fnames = list(lines)
예제 #17
0
 def _video_transform(self, mode: str):
     """Build the transform applied to the "video" key of each sample.

     This combines PyTorchVideo and TorchVision transforms in one Callable.
     Frames are first temporally subsampled and normalized. In ``"train"``
     mode they are then randomly short-side-scaled, randomly cropped and
     randomly flipped. In any other mode they are deterministically
     short-side-scaled and center-cropped.
     """
     args = self.args
     temporal = [
         UniformTemporalSubsample(args.video_num_subsampled),
         Normalize(args.video_means, args.video_stds),
     ]
     if mode == "train":
         spatial = [
             RandomShortSideScale(
                 min_size=args.video_min_short_side_scale,
                 max_size=args.video_max_short_side_scale,
             ),
             RandomCrop(args.video_crop_size),
             RandomHorizontalFlip(p=args.video_horizontal_flip_p),
         ]
     else:
         spatial = [
             ShortSideScale(args.video_min_short_side_scale),
             CenterCrop(args.video_crop_size),
         ]
     pipeline = Compose(temporal + spatial)
     return ApplyTransformToKey(key="video", transform=pipeline)
예제 #18
0
def _get_headpose_dataset(dataset, opt, mean, std, attrs):
    """Build the head-pose train/val datasets and split them.

    Only ``'300W_LP'`` is supported; data is read from
    ``<opt.root_path>/300W_LP``. Training uses Resize(240) + RandomCrop(224);
    validation resizes to ``(opt.face_size, opt.face_size)``. Both normalize
    with ``mean``/``std``, and targets are built with
    ``ToMaskedTargetTensor(attrs)``.

    Returns:
        Whatever ``split_dataset_into_train_val(train, val, val_ratio=0.05)``
        returns (presumably the train/val datasets — confirm).

    Raises:
        ValueError: for any other ``dataset`` name (an ``Exception``
            subclass, so generic handlers still catch it).
    """
    if dataset != '300W_LP':
        raise ValueError(
            'Error: not a valid dataset name: {!r}'.format(dataset))

    train_transform = Compose(
        [Resize(240),
         RandomCrop(224),
         ToTensor(),
         Normalize(mean, std)])
    val_transform = Compose([
        Resize((opt.face_size, opt.face_size)),
        ToTensor(),
        Normalize(mean, std)
    ])
    target_transform = ToMaskedTargetTensor(attrs)

    root = os.path.join(opt.root_path, '300W_LP')
    train_data = Pose_300W_LP(root,
                              transform=train_transform,
                              target_transform=target_transform,
                              training=True)
    val_data = Pose_300W_LP(root,
                            transform=val_transform,
                            target_transform=target_transform,
                            training=False)

    return split_dataset_into_train_val(train_data, val_data, val_ratio=0.05)
예제 #19
0
def get_train_test_loaders(dataset_name, path, batch_size, num_workers):
    """Build train/test DataLoaders for a dataset class from ``datasets``.

    Training images are padded by 2 pixels, randomly cropped to 32 and
    randomly flipped. Both splits are normalized with a 3-value mean/std
    (so 3-channel images are assumed).

    Args:
        dataset_name: Name of a class in the ``datasets`` module.
        path: Download/root directory; only the train split downloads.
        batch_size: Training batch size; the test loader uses twice this.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, test_loader)``. The training loader reshuffles every
        epoch; the test loader keeps dataset order.
    """
    assert dataset_name in datasets.__dict__, "Unknown dataset name {}".format(dataset_name)
    fn = datasets.__dict__[dataset_name]

    train_transform = Compose([
        Pad(2),
        RandomCrop(32),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
    ])

    test_transform = Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
    ])

    train_ds = fn(root=path, train=True, transform=train_transform, download=True)
    test_ds = fn(root=path, train=False, transform=test_transform, download=False)

    # DataLoader defaults to shuffle=False; training batches must be shuffled.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size * 2, num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader
예제 #20
0
 def __init__(self, image_dir, patch_size, scale_factor, data_augmentation=True):
     """Index the .bmp/.jpg/.png images directly inside ``image_dir``.

     Extensions are matched case-insensitively (so ``.PNG`` is included).
     File names are sorted so sample order does not depend on filesystem
     listing order.
     """
     super(DatasetFromFolder, self).__init__()
     image_suffixes = {'.bmp', '.jpg', '.png'}
     self.filenames = sorted(
         str(filename) for filename in Path(image_dir).glob('*')
         if filename.suffix.lower() in image_suffixes)
     self.patch_size = patch_size
     self.scale_factor = scale_factor
     self.data_augmentation = data_augmentation
     self.crop = RandomCrop(self.patch_size)
예제 #21
0
    def prepare_data(self):
        """Prepare supervised and unsupervised datasets from CIFAR or QuickDraw.

        ``self.hparams.dataset`` selects "cifar" or "quickdraw". The matching
        train/validation transforms are built, and the supervised split and
        validation set are stored in ``self.train_ds`` and ``self.valid_ds``.

        Raises:
            ValueError: if ``self.hparams.dataset`` is not supported.
        """
        dataset_path = self.hparams.dataset_path
        n_labeled = self.hparams.n_labeled
        n_overlap = self.hparams.n_overlap
        seed = self.hparams.seed
        dataset = self.hparams.dataset

        if dataset == "cifar":
            n = self.hparams.randaug_n
            m = self.hparams.randaug_m
            if self.hparams.strong_tfm:
                train_tfm = Compose([RandAugment(n, m), ToTensor(), Normalize(*CIFAR_STATS)])
            else:
                train_tfm = Compose([RandomCrop(32, 4, padding_mode="reflect"), RandomHorizontalFlip(), ToTensor(), Normalize(*CIFAR_STATS)])
            valid_tfm = Compose([ToTensor(), Normalize(*CIFAR_STATS)])
            sup_ds, unsup_ds = Cifar.uda_ds(dataset_path, n_labeled, n_overlap, train_tfm, seed=seed)
            val_ds = Cifar.val_ds(dataset_path, valid_tfm)
        elif dataset == "quickdraw":
            train_tfm = Compose([ExpandChannels, SketchDeformation, RandomHorizontalFlip(), RandomRotation(30), RandomCrop(128, 18), ToTensor()])
            valid_tfm = Compose([ExpandChannels, ToTensor()])
            sup_ds, unsup_ds = QuickDraw.uda_ds(dataset_path, n_labeled, n_overlap, train_tfm, seed=seed)
            val_ds = QuickDraw.val_ds(dataset_path, valid_tfm)
        else:
            # Without this, an unknown name crashes below with UnboundLocalError.
            raise ValueError("Unknown dataset: {!r}".format(dataset))

        # NOTE(review): unsup_ds is built but not stored here — confirm whether
        # it should be kept for the unsupervised branch.
        self.train_ds = sup_ds
        self.valid_ds = val_ds
        print("Loaded {} train examples and {} validation examples".format(len(self.train_ds), len(self.valid_ds)))
예제 #22
0
def _input_transform(crop_size, upscale_factor, patch_size=None):
    """Center-crop, downscale by ``upscale_factor``, random-crop a patch, then ToTensor.

    The random crop uses ``patch_size`` when given; otherwise it uses
    ``crop_size // upscale_factor``.
    """
    lr_size = crop_size // upscale_factor
    patch = lr_size if patch_size is None else patch_size
    steps = [
        CenterCrop(crop_size),
        Resize(lr_size),
        RandomCrop(patch),
        ToTensor(),
    ]
    return Compose(steps)
예제 #23
0
def ImageTransform(loadSize, cropSize):
    """Bicubic resize to ``loadSize``, random ``cropSize`` crop, p=0.5 flip, then ToTensor."""
    resize = Resize(size=loadSize, interpolation=Image.BICUBIC)
    crop = RandomCrop(size=cropSize)
    flip = RandomHorizontalFlip(p=0.5)
    return Compose([resize, crop, flip, ToTensor()])
예제 #24
0
def get_train_transform(length=T):
    """Training pipeline that takes a random window of ``length`` columns.

    The input is padded by ``length // 2`` on the left/right and cropped to
    ``(1, length)`` (so single-row inputs are presumably expected), then
    converted to a tensor and passed to ``Centring(MAX_INT)``. Every step is
    wrapped in ``ConvertToTuple``.
    """
    steps = [
        ToPILImage(),
        Pad((length // 2, 0)),
        RandomCrop((1, length)),
        ToTensor(),
        Centring(MAX_INT),
    ]
    wrapped = [ConvertToTuple(step) for step in steps]
    return transforms.Compose(wrapped)
예제 #25
0
	def __init__(self, seq_name, vis_threshold, P, K, max_per_person, crop_H, crop_W,
				transform, normalize_mean=None, normalize_std=None):
		"""Market-1501 dataset rooted at ``cfg.DATA_DIR/Market-1501-v15.09.15``.

		``transform`` must be "random" (random crop + horizontal flip) or
		"center" (center crop); anything else raises NotImplementedError.
		A non-empty ``seq_name`` must be one of the listed image sets; an
		empty one yields no images. ``vis_threshold`` is not used in this method.
		"""
		self.data_dir = osp.join(cfg.DATA_DIR, 'Market-1501-v15.09.15')
		self.seq_name = seq_name

		self.P, self.K = P, K
		self.max_per_person = max_per_person
		self.crop_H, self.crop_W = crop_H, crop_W

		crop_size = (crop_H, crop_W)
		if transform == "random":
			steps = [RandomCrop(crop_size), RandomHorizontalFlip()]
		elif transform == "center":
			steps = [CenterCrop(crop_size)]
		else:
			raise NotImplementedError("Tranformation not understood: {}".format(transform))
		steps += [ToTensor(), Normalize(normalize_mean, normalize_std)]
		self.transform = Compose(steps)

		if not seq_name:
			self.data = []
		else:
			assert seq_name in ['bounding_box_test', 'bounding_box_train', 'gt_bbox'], \
				'Image set does not exist: {}'.format(seq_name)
			self.data = self.load_images()

		self.build_samples()
예제 #26
0
def HR_8_transform(crop_size):
    """Random ``crop_size`` crop, ``RandomScale()``, then a random horizontal flip."""
    steps = [RandomCrop(crop_size), RandomScale(), RandomHorizontalFlip()]
    return Compose(steps)
예제 #27
0
def build_dataset(source_domain_name,
                  target_domain_name):
    """Build the torch Datasets for a source -> target domain-adaptation run.

    Both training sets use Resize(256) -> RandomCrop(224) -> flip ->
    RandomRotation(30, fill=128) -> ToTensor -> Normalize. The test set uses
    the target domain with Resize(256) -> CenterCrop(224) -> ToTensor ->
    Normalize.

    Args:
        source_domain_name (string): key into ``root_dir`` for the source domain.
        target_domain_name (string): key into ``root_dir`` for the target domain.

    Returns:
        datasets (dict): ``'train_source'``, ``'train_target'`` and ``'test'``
        mapped to ``ImageFolder`` datasets.
    """
    normalize = Normalize(IMAGENET_MEAN, IMAGENET_STD)
    train_tf = Compose([
        Resize([256, 256]),
        RandomCrop([224, 224]),
        RandomHorizontalFlip(),
        RandomRotation(degrees=30, fill=128),
        ToTensor(),
        normalize,
    ])
    eval_tf = Compose([
        Resize([256, 256]),
        CenterCrop([224, 224]),
        ToTensor(),
        normalize,
    ])

    source_root = root_dir[source_domain_name]
    target_root = root_dir[target_domain_name]
    return {
        'train_source': ImageFolder(root=source_root, transform=train_tf),
        'train_target': ImageFolder(root=target_root, transform=train_tf),
        'test': ImageFolder(root=target_root, transform=eval_tf),
    }
예제 #28
0
    def default_transforms(self) -> Dict[str, Callable]:
        """Default video transforms, keyed by the stage they run in.

        ``"post_tensor_transform"``: 8 uniformly subsampled frames, then, in
        training, random short-side scale (256-320), RandomCrop(244) and a
        p=0.5 horizontal flip, or, otherwise, ShortSideScale(256).
        ``"per_batch_transform_on_device"``: Kornia per-channel normalization
        (mean 0.45, std 0.225) on BCTHW batches.
        """
        # NOTE(review): 244 may be a typo for the usual 224 — confirm.
        if self.training:
            spatial = [
                RandomShortSideScale(min_size=256, max_size=320),
                RandomCrop(244),
                RandomHorizontalFlip(p=0.5),
            ]
        else:
            spatial = [ShortSideScale(256)]

        per_sample = Compose([
            ApplyTransformToKey(
                key="video",
                transform=Compose([UniformTemporalSubsample(8)] + spatial),
            ),
        ])

        video_normalize = K.VideoSequential(
            K.Normalize(torch.tensor([0.45, 0.45, 0.45]),
                        torch.tensor([0.225, 0.225, 0.225])),
            data_format="BCTHW",
            same_on_frame=False,
        )
        on_device = Compose([
            ApplyTransformToKey(key="video", transform=video_normalize),
        ])

        return {
            "post_tensor_transform": per_sample,
            "per_batch_transform_on_device": on_device,
        }
예제 #29
0
파일: dataset.py 프로젝트: samxuxiang/mcmi
  def __init__(self, opts):
    """Index domain A/B image files and build the shared image transform.

    Files are listed from ``<dataroot>/<phase>A`` and ``<dataroot>/<phase>B``.
    The transform is a bicubic resize to ``resize_size``, then a random crop
    ('train' phase) or center crop (otherwise) to ``crop_size``, an optional
    horizontal flip (unless ``no_flip``), ToTensor and Normalize to [-1, 1].
    Prints the per-domain image counts.
    """
    self.dataroot = opts.dataroot

    def _list_domain(letter):
      folder = os.path.join(self.dataroot, opts.phase + letter)
      return [os.path.join(folder, name) for name in os.listdir(folder)]

    self.A = _list_domain('A')
    self.B = _list_domain('B')

    self.A_size, self.B_size = len(self.A), len(self.B)
    self.dataset_size = max(self.A_size, self.B_size)
    self.input_dim_A = opts.input_dim_a
    self.input_dim_B = opts.input_dim_b

    # setup image transformation
    steps = [Resize((opts.resize_size, opts.resize_size), Image.BICUBIC)]
    crop_cls = RandomCrop if opts.phase == 'train' else CenterCrop
    steps.append(crop_cls(opts.crop_size))
    if not opts.no_flip:
      steps.append(RandomHorizontalFlip())
    steps += [ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]
    self.transforms = Compose(steps)
    print('A: %d, B: %d images'%(self.A_size, self.B_size))
예제 #30
0
    def test_twice_transform(self):
        """Visually inspect ``SequentialWrapperTwice`` on ACDC training sample 4.

        The joint rotation/crop is applied to image and target, and the
        colour jitter to the image only. Both augmented views are then shown
        with ``multi_slice_viewer_debug`` followed by ``plt.show()``.
        """
        from torchvision.transforms import Compose, RandomCrop, RandomRotation, ColorJitter, ToTensor
        joint_transform = Compose([RandomRotation(45), RandomCrop(224)])
        image_only_transform = Compose([
            ColorJitter(brightness=[0.8, 1.2],
                        contrast=[0.8, 1.2],
                        saturation=1),
            ToTensor(),
        ])
        transforms = SequentialWrapperTwice(
            com_transform=joint_transform,
            image_transform=image_only_transform,
            target_transform=ToLabel(),
            total_freedom=False,
        )

        dataset = ACDCDataset(root_dir=self._root, mode="train", transforms=transforms)
        (image1, image2, target1, target2), _filename = dataset[4]

        from deepclustering3.viewer import multi_slice_viewer_debug
        import matplotlib.pyplot as plt
        images = torch.cat([image1, image2], dim=0)
        targets = torch.cat([target1, target2], dim=0)
        multi_slice_viewer_debug(images, targets, no_contour=True)
        plt.show()