def get_loader(dataset, dataroot, origin_size, image_size, batch_size, workers,
               split='train', shuffle=True, seed=None):
    """Build a DataLoader for one of the supported image datasets.

    Args:
        dataset: Dataset name; one of 'trans', 'folder' or 'pix2pix'.
        dataroot: Directory passed to the dataset constructor as ``root``.
        origin_size: Size passed to the resize transform before cropping.
        image_size: Crop size (random crop when ``split == 'train'``,
            center crop otherwise).
        batch_size: Batch size forwarded to ``DataLoader``.
        workers: Number of worker processes forwarded to ``DataLoader``.
        split: 'train' enables random crop + random horizontal flip; any
            other value uses a deterministic center crop.
        shuffle: Forwarded to ``DataLoader``.
        seed: Forwarded to the dataset constructor for 'trans' and
            'pix2pix'. It is NOT forwarded for 'folder', because torchvision's
            ``ImageFolder`` does not accept a ``seed`` argument.

    Returns:
        A ``DataLoader`` over the constructed dataset.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
        AssertionError: If the constructed dataset is falsy (e.g. it reports
            ``len() == 0``).
    """
    if dataset == 'trans':
        from datasets.trans import trans as commonDataset
        import transforms.pix2pix as transforms
    elif dataset == 'folder':
        from torchvision.datasets.folder import ImageFolder as commonDataset
        import torchvision.transforms as transforms
    elif dataset == 'pix2pix':
        from datasets.pix2pix import pix2pix as commonDataset
        import transforms.pix2pix as transforms
    else:
        # Without this, an unknown name surfaced as an UnboundLocalError below.
        raise ValueError('Unknown dataset: {!r}'.format(dataset))

    # Prefer Resize; fall back to the legacy Scale name in case a transforms
    # module only provides that one (NOTE(review): assumed possible for the
    # project's transforms.pix2pix, which other loaders use via Scale).
    resize = getattr(transforms, 'Resize', None) or transforms.Scale

    if split == 'train':
        # Training-time augmentation: random crop plus random horizontal flip.
        crop_steps = [transforms.RandomCrop(image_size),
                      transforms.RandomHorizontalFlip()]
    else:
        # Validation/test: deterministic center crop.
        crop_steps = [transforms.CenterCrop(image_size)]

    transform = transforms.Compose(
        [resize(origin_size)] + crop_steps + [
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

    if dataset == 'folder':
        # torchvision's ImageFolder has no ``seed`` keyword; passing it raises TypeError.
        data = commonDataset(root=dataroot, transform=transform)
    else:
        data = commonDataset(root=dataroot, transform=transform, seed=seed)

    assert data
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle,
                      num_workers=workers)
def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train',
              shuffle=True, seed=None):
    """Build a DataLoader for a pix2pix-style or folder/classification dataset.

    NOTE(review): this file appears to define ``getLoader`` more than once;
    only the last definition is bound after import. Confirm which one callers
    actually get.

    Args:
        datasetName: One of 'pix2pix', 'folder', 'classification',
            'pix2pix_val' or 'pix2pix_val2'.
        dataroot: Directory passed to the dataset constructor as ``root``.
        originalSize: Size passed to the resize transform before cropping.
        imageSize: Crop size (random crop for 'train', center crop otherwise).
        batchSize: Batch size forwarded to ``DataLoader``.
        workers: Number of worker processes; coerced with ``int()``.
        mean: Per-channel mean for ``Normalize``.
        std: Per-channel std for ``Normalize``.
        split: 'train' enables random crop + horizontal flip; any other value
            uses a center crop.
        shuffle: Forwarded to ``DataLoader``.
        seed: Forwarded to the dataset constructor.

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset.

    Raises:
        ValueError: If ``datasetName`` is not a supported name.
    """
    if datasetName == 'pix2pix':
        from datasets.pix2pix import pix2pix as commonDataset
        import transforms.pix2pix as transforms
    elif datasetName == 'folder':
        from datasets.folder import ImageFolder as commonDataset
        import torchvision.transforms as transforms
    elif datasetName == 'classification':
        from datasets.classification import classification as commonDataset
        import torchvision.transforms as transforms
    elif datasetName == 'pix2pix_val':
        from datasets.pix2pix_val import pix2pix_val as commonDataset
        import torchvision.transforms as transforms
    elif datasetName == 'pix2pix_val2':
        from datasets.pix2pix_val2 import pix2pix_val as commonDataset
        import transforms.pix2pix as transforms
    else:
        # Without this, an unknown name surfaced as an UnboundLocalError below.
        raise ValueError('Unknown datasetName: {!r}'.format(datasetName))

    # Keep the legacy Scale transform where it exists (original behavior), and
    # fall back to Resize for torchvision releases that removed Scale.
    resize = getattr(transforms, 'Scale', None) or transforms.Resize

    if split == 'train':
        # Training-time augmentation: random crop plus random horizontal flip.
        crop_steps = [transforms.RandomCrop(imageSize),
                      transforms.RandomHorizontalFlip()]
    else:
        crop_steps = [transforms.CenterCrop(imageSize)]

    transform = transforms.Compose(
        [resize(originalSize)] + crop_steps + [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    dataset = commonDataset(root=dataroot, transform=transform, seed=seed)
    return torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                                       shuffle=shuffle, num_workers=int(workers))
def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train',
              shuffle=True, seed=None, pre="", label_file=""):
    """Build a DataLoader for the project's ``my_loader`` dataset.

    NOTE(review): this file appears to define ``getLoader`` more than once;
    only the last definition is bound after import. Confirm which one callers
    actually get.

    Args:
        datasetName: Must be 'my_loader' (the only supported name).
        dataroot: Directory passed to the dataset constructor as ``root``.
        originalSize: Size passed to the resize transform before cropping.
        imageSize: Crop size (random crop for 'train', center crop otherwise).
        batchSize: Batch size forwarded to ``DataLoader``.
        workers: Number of worker processes; coerced with ``int()``.
        mean: Per-channel mean for ``Normalize``.
        std: Per-channel std for ``Normalize``.
        split: 'train' enables random crop + horizontal flip; any other value
            uses a center crop.
        shuffle: Forwarded to ``DataLoader``.
        seed: Forwarded to the dataset constructor.
        pre: Forwarded to the dataset constructor (meaning defined by
            ``my_loader``; not visible here).
        label_file: Forwarded to the dataset constructor (meaning defined by
            ``my_loader``; not visible here).

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset.

    Raises:
        ValueError: If ``datasetName`` is not 'my_loader'.
    """
    if datasetName != 'my_loader':
        # Without this, any other name surfaced as an UnboundLocalError below.
        raise ValueError('Unknown datasetName: {!r}'.format(datasetName))

    from datasets.my_loader import my_loader as commonDataset
    import transforms.pix2pix as transforms

    # Keep the legacy Scale transform where it exists (original behavior),
    # falling back to Resize if the transforms module does not provide it.
    resize = getattr(transforms, 'Scale', None) or transforms.Resize

    if split == 'train':
        # Training-time augmentation: random crop plus random horizontal flip.
        crop_steps = [transforms.RandomCrop(imageSize),
                      transforms.RandomHorizontalFlip()]
    else:
        crop_steps = [transforms.CenterCrop(imageSize)]

    transform = transforms.Compose(
        [resize(originalSize)] + crop_steps + [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    dataset = commonDataset(root=dataroot, transform=transform, seed=seed,
                            pre=pre, label_file=label_file)
    return torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                                       shuffle=shuffle, num_workers=int(workers))
def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train',
              shuffle=True, seed=None):
    """Build a DataLoader for one of the paired pix2pix-style datasets.

    Args:
        datasetName: One of 'pix2pix', 'pix2pix_val' or 'pix2pix_class'.
        dataroot: Directory passed to the dataset constructor as ``root``.
        originalSize: Size passed to the resize transform before cropping
            (unused when ``split == 'no_change'``).
        imageSize: Crop size (unused when ``split == 'no_change'``).
        batchSize: Batch size forwarded to ``DataLoader``.
        workers: Number of worker processes; coerced with ``int()``.
        mean: Per-channel mean for ``Normalize``.
        std: Per-channel std for ``Normalize``.
        split: 'no_change' skips resizing/cropping entirely; 'train' resizes,
            random-crops and randomly flips; any other value resizes and
            center-crops.
        shuffle: Forwarded to ``DataLoader``.
        seed: Forwarded to the dataset constructor.

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset.

    Raises:
        ValueError: If ``datasetName`` is not a supported name.
    """
    if datasetName == 'pix2pix':
        from datasets.pix2pix import pix2pix as commonDataset
    elif datasetName == 'pix2pix_val':
        from datasets.pix2pix_val import pix2pix_val as commonDataset
    elif datasetName == 'pix2pix_class':
        from datasets.pix2pix_class import pix2pix as commonDataset
    else:
        # Without this, an unknown name surfaced as an UnboundLocalError below.
        raise ValueError('Unknown datasetName: {!r}'.format(datasetName))
    # Every supported dataset here uses the project's paired-image transforms.
    import transforms.pix2pix as transforms

    if split == 'no_change':
        # Keep images at their stored size: no resize, crop or flip.
        geometry_steps = []
    else:
        # Keep the legacy Scale transform where it exists (original behavior),
        # falling back to Resize if the transforms module does not provide it.
        resize = getattr(transforms, 'Scale', None) or transforms.Resize
        if split == 'train':
            # Training-time augmentation: random crop plus random horizontal flip.
            geometry_steps = [resize(originalSize),
                              transforms.RandomCrop(imageSize),
                              transforms.RandomHorizontalFlip()]
        else:
            # The same CenterCrop(imageSize) is used for both the DID-MDN and
            # IDC_GAN datasets.
            geometry_steps = [resize(originalSize),
                              transforms.CenterCrop(imageSize)]

    transform = transforms.Compose(geometry_steps + [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    dataset = commonDataset(root=dataroot, transform=transform, seed=seed)
    return torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                                       shuffle=shuffle, num_workers=int(workers))