def get_eval_loader(root, img_size=256, batch_size=32, imagenet_normalize=True,
                    shuffle=True, num_workers=4, drop_last=False):
    """Return a DataLoader over ``DefaultDataset(root)`` for the evaluation phase.

    Every image is first resized to ``img_size x img_size``. It is then resized
    to the final evaluation resolution and normalized:

    * ``imagenet_normalize=True``: resized to 299x299 and normalized with the
      per-channel mean ``[0.485, 0.456, 0.406]`` and std ``[0.229, 0.224, 0.225]``
      (presumably for an ImageNet-pretrained evaluation network -- confirm with
      callers).
    * ``imagenet_normalize=False``: kept at ``img_size x img_size`` and
      normalized with mean/std 0.5 per channel.

    Args:
        root: directory handed to ``DefaultDataset``.
        img_size: side length of the first resize.
        batch_size: images per batch.
        imagenet_normalize: selects between the two settings above.
        shuffle: whether the DataLoader shuffles samples.
        num_workers: number of DataLoader worker processes.
        drop_last: whether an incomplete final batch is dropped.

    Returns:
        A ``data.DataLoader`` with ``pin_memory=True``.
    """
    print('Preparing DataLoader for the evaluation phase...')

    if not imagenet_normalize:
        target_hw = [img_size, img_size]
        norm_mean = [0.5, 0.5, 0.5]
        norm_std = [0.5, 0.5, 0.5]
    else:
        target_hw = [299, 299]
        norm_mean = [0.485, 0.456, 0.406]
        norm_std = [0.229, 0.224, 0.225]

    # The first resize always goes to img_size; the second one is a no-op in
    # the non-ImageNet case and an upscale/downscale to 299 otherwise.
    pipeline = [
        transforms.Resize([img_size, img_size]),
        transforms.Resize(target_hw),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
    eval_set = DefaultDataset(root, transform=transforms.Compose(pipeline))
    return data.DataLoader(
        dataset=eval_set,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
    )
def get_test_loader(root, img_size=256, batch_size=32, shuffle=True,
                    num_workers=4):
    """Return a DataLoader over ``ImageFolder(root)`` for the generation phase.

    Each image is resized to ``img_size x img_size`` and converted to a tensor.
    It is then normalized per channel as ``(x - 0.5) / 0.5``.

    Args:
        root: directory handed to ``ImageFolder``.
        img_size: output side length.
        batch_size: images per batch.
        shuffle: whether the DataLoader shuffles samples.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``data.DataLoader`` with ``pin_memory=True``.
    """
    print('Preparing DataLoader for the generation phase...')

    steps = [transforms.Resize([img_size, img_size]), transforms.ToTensor()]
    steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    folder = ImageFolder(root, transform=transforms.Compose(steps))

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data.DataLoader(dataset=folder, **loader_kwargs)
class _RandomlyApplied(object):
    """Callable that applies ``fn`` to its input with probability ``prob``.

    This is a module-level class rather than an inline ``lambda`` so that the
    training transform pipeline can be pickled. Multi-process DataLoader
    workers need a picklable dataset when they are started with the ``spawn``
    method, which is the default on Windows and macOS.
    """

    def __init__(self, fn, prob):
        self.fn = fn
        self.prob = prob

    def __call__(self, x):
        # The same draw as the previous closure: Python's global ``random``.
        if random.random() < self.prob:
            return self.fn(x)
        return x

    def __repr__(self):
        return '%s(fn=%r, prob=%r)' % (type(self).__name__, self.fn, self.prob)


def get_train_loader(root, which='source', img_size=256, batch_size=8,
                     prob=0.5, num_workers=4):
    """Build the training-phase DataLoader for source or reference images.

    Augmentation pipeline:
      1. With probability ``prob``, apply ``RandomResizedCrop(img_size)``
         (scale in [0.8, 1.0], aspect ratio in [0.9, 1.1]).
      2. Resize to ``img_size x img_size``.
      3. Apply a random horizontal flip.
      4. Convert to a tensor and normalize per channel as ``(x - 0.5) / 0.5``.

    Args:
        root: dataset directory.
        which: ``'source'`` builds ``ImageFolder(root)``. ``'reference'``
            builds ``ReferenceDataset(root, transform)``.
        img_size: output side length.
        batch_size: images per batch, passed to ``_make_balanced_sampler``.
            It is also used directly when that function returns ``None``.
        prob: probability of applying the random resized crop.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``data.DataLoader`` with ``pin_memory=False``. When
        ``_make_balanced_sampler`` returns a sampler, it is used as the
        ``batch_sampler``. Otherwise batches of ``batch_size`` are formed
        without an explicit sampler or shuffling, and the last incomplete
        batch is dropped.

    Raises:
        NotImplementedError: if ``which`` is neither ``'source'`` nor
            ``'reference'``.
    """
    print('Preparing DataLoader to fetch %s images '
          'during the training phase...' % which)

    crop = transforms.RandomResizedCrop(
        img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1])
    # A picklable callable instead of a closure; see _RandomlyApplied.
    rand_crop = transforms.Lambda(_RandomlyApplied(crop, prob))

    transform = transforms.Compose([
        rand_crop,
        transforms.Resize([img_size, img_size]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    if which == 'source':
        dataset = ImageFolder(root, transform=transform)
    elif which == 'reference':
        dataset = ReferenceDataset(root, transform)
    else:
        raise NotImplementedError(
            "Unsupported which=%r; expected 'source' or 'reference'" % (which,))

    # NOTE(review): assumes both dataset types expose ``targets`` (per-sample
    # labels) -- confirm for ReferenceDataset.
    sampler = _make_balanced_sampler(dataset.targets, batch_size)
    if sampler is not None:
        # Batching is delegated to the batch_sampler, so batch_size stays 1.
        return data.DataLoader(dataset=dataset,
                               batch_size=1,
                               batch_sampler=sampler,
                               num_workers=num_workers,
                               pin_memory=False)
    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           batch_sampler=None,
                           num_workers=num_workers,
                           pin_memory=False,
                           drop_last=True)
def build_model(self):
    """Create the data pipeline, networks, losses and optimizers.

    Reads these attributes from ``self``: ``img_size``, ``dataset``,
    ``is_parallel``, ``batch_size``, ``ch``, ``n_res``, ``light``, ``lr`` and
    ``weight_decay``. It also reads ``strategy`` when ``is_parallel`` is true.

    Sets these attributes on ``self``:
      * ``trainA``/``trainB``/``testA``/``testB``: ``ImageFolder`` datasets
        read from ``dataset/<self.dataset>/{trainA,trainB,testA,testB}``.
      * ``trainA_loader``/``trainB_loader``/``testA_loader``/``testB_loader``.
      * ``genA2B``/``genB2A`` (``ResnetGenerator``).
      * ``disGA``/``disGB`` (``Discriminator``, 7 layers) and ``disLA``/``disLB``
        (5 layers).
      * In parallel mode, every network is wrapped in ``DataParallel``.
      * ``L1_loss``, ``MSE_loss``, ``BCE_loss``.
      * ``G_optim`` (generators) and ``D_optim`` (discriminators), both Adam
        with ``betas=(0.5, 0.999)``.
      * ``Rho_clipper`` (``RhoClipper(0, 1)``).
    """
    # ---- DataLoader -------------------------------------------------------
    # Training: random flip, upscale by 30 px, random-crop back to img_size.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((self.img_size + 30, self.img_size + 30)),
        transforms.RandomCrop(self.img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    # Testing: deterministic resize only.
    test_transform = transforms.Compose([
        transforms.Resize((self.img_size, self.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    self.trainA = ImageFolder(
        os.path.join('dataset', self.dataset, 'trainA'), train_transform)
    self.trainB = ImageFolder(
        os.path.join('dataset', self.dataset, 'trainB'), train_transform)
    self.testA = ImageFolder(
        os.path.join('dataset', self.dataset, 'testA'), test_transform)
    self.testB = ImageFolder(
        os.path.join('dataset', self.dataset, 'testB'), test_transform)

    if self.is_parallel:
        from paddorch.utils.data.sampler import DistributedBatchSampler
        # The distributed sampler only shuffles when asked to (its default is
        # shuffle=False). Shuffle the training splits so parallel training
        # sees a random order, matching the single-process branch below.
        batch_sampler_trainA = DistributedBatchSampler(
            self.trainA, self.batch_size, shuffle=True)
        batch_sampler_trainB = DistributedBatchSampler(
            self.trainB, self.batch_size, shuffle=True)
        # NOTE(review): test samplers batch by self.batch_size, whereas the
        # single-process branch uses batch_size=1 for test loaders. Confirm
        # this is intended.
        batch_sampler_testA = DistributedBatchSampler(
            self.testA, self.batch_size)
        batch_sampler_testB = DistributedBatchSampler(
            self.testB, self.batch_size)
        # Batching is delegated to batch_sampler, so batch_size stays 1.
        self.trainA_loader = DataLoader(
            self.trainA, batch_size=1, batch_sampler=batch_sampler_trainA)
        self.trainB_loader = DataLoader(
            self.trainB, batch_size=1, batch_sampler=batch_sampler_trainB)
        self.testA_loader = DataLoader(
            self.testA, batch_size=1, batch_sampler=batch_sampler_testA)
        self.testB_loader = DataLoader(
            self.testB, batch_size=1, batch_sampler=batch_sampler_testB)
    else:
        self.trainA_loader = DataLoader(
            self.trainA, batch_size=self.batch_size, shuffle=True)
        self.trainB_loader = DataLoader(
            self.trainB, batch_size=self.batch_size, shuffle=True)
        self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
        self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)

    # ---- Generators and discriminators ------------------------------------
    self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch,
                                  n_blocks=self.n_res, img_size=self.img_size,
                                  light=self.light)
    self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch,
                                  n_blocks=self.n_res, img_size=self.img_size,
                                  light=self.light)
    self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
    self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
    self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
    self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
    if self.is_parallel:
        from paddle.fluid.dygraph import DataParallel
        self.genA2B = DataParallel(self.genA2B, self.strategy)
        self.genB2A = DataParallel(self.genB2A, self.strategy)
        self.disGA = DataParallel(self.disGA, self.strategy)
        self.disGB = DataParallel(self.disGB, self.strategy)
        self.disLA = DataParallel(self.disLA, self.strategy)
        self.disLB = DataParallel(self.disLB, self.strategy)

    # ---- Losses -----------------------------------------------------------
    self.L1_loss = nn.L1Loss()
    self.MSE_loss = nn.MSELoss()
    self.BCE_loss = nn.BCEWithLogitsLoss()

    # ---- Optimizers -------------------------------------------------------
    # Each parameters() call is wrapped in list(...) so the concatenation
    # works whether a backend returns a list or a generator.
    self.G_optim = torch.optim.Adam(
        list(self.genA2B.parameters()) + list(self.genB2A.parameters()),
        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
    self.D_optim = torch.optim.Adam(
        list(self.disGA.parameters()) + list(self.disGB.parameters())
        + list(self.disLA.parameters()) + list(self.disLB.parameters()),
        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
    # Disabled mixed-precision wrapping, kept for reference:
    # self.G_optim = fluid.contrib.mixed_precision.decorator.decorate(self.G_optim)
    # self.D_optim = fluid.contrib.mixed_precision.decorator.decorate(self.D_optim)

    # ---- Rho clipper ------------------------------------------------------
    # Constrains the value of rho in AdaILN and ILN. The bounds (0, 1) are
    # presumably a clamp range; see RhoClipper.
    self.Rho_clipper = RhoClipper(0, 1)