def _parse_args():
    """Build the WSISR training command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Train WSISR on compressed TMA dataset')
    parser.add_argument('--batch-size', default=32, type=int, help='Batch size')
    parser.add_argument('--patch-size', default=224, type=int, help='Patch size')
    parser.add_argument('--up-scale', default=5, type=float, help='Targeted upscale factor')
    parser.add_argument('--num-workers', default=1, type=int, help='Number of workers')
    parser.add_argument('--num-epochs', default=900, type=int, help='Number of epochs, more epochs are desired for GAN training')
    parser.add_argument('--g-lr', default=0.0001, type=float, help='Learning rate of the generator')
    parser.add_argument('--d-lr', default=0.00001, type=float, help='Learning rate of the descriminator')
    parser.add_argument('--percep-weight', default=0.01, type=float, help='GAN loss weight')
    parser.add_argument('--run-from', default=None, type=str, help='Load weights from a previous run, use folder name in [weights] folder')
    parser.add_argument('--start-epoch', default=1, type=int, help='Starting epoch for the curriculum, start at 1/2 of the epochs to skip the curriculum')
    parser.add_argument('--gan', default=1, type=int, help='Use GAN')
    parser.add_argument('--num-critic', default=1, type=int, help='Interval of training the descriminator')
    return parser.parse_args()


def _curriculum_factor(epoch, init_scale, step_size, up_scale):
    """Return the curriculum upscale factor (log2 units) for 1-based ``epoch``.

    The linear scale starts at ``init_scale`` on epoch 1, grows by
    ``step_size`` per epoch, and the resulting log2 factor is capped at
    ``up_scale``.
    """
    return min(log2(init_scale + (epoch - 1) * step_size), up_scale)


def main():
    """Train the WSISR generator/discriminator on the compressed TMA dataset.

    Parses CLI options, optionally resumes weights from ``weights/<run_from>``,
    then runs a curriculum loop. The upscale factor ramps from 2 (log2 of 4)
    to ``--up-scale`` over the first half of the epochs. Validation runs after
    every epoch, and a final stitched validation runs after training.
    """
    args = _parse_args()
    warnings.filterwarnings('ignore')
    device = torch.device('cuda:0')
    tensor = torch.cuda.FloatTensor

    data.generate_compress_csv()
    valid_dataset = new_compress_curriculum(args, args.up_scale, 'valid')

    generator = models.Generator()
    generator.to(device)
    discriminator = models.Discriminator()
    discriminator.to(device)
    criterionL = nn.L1Loss().to(device)
    criterionMSE = nn.MSELoss().to(device)
    # NOTE(review): presumably the discriminator's patch output shape
    # (patch_size / 2**4 per side) -- confirm against models.Discriminator.
    patch = (1, args.patch_size // 2 ** 4, args.patch_size // 2 ** 4)

    if args.run_from is not None:
        weights_dir = os.path.join('weights', args.run_from)
        generator.load_state_dict(
            torch.load(os.path.join(weights_dir, 'generator.pth'), map_location=device))
        try:
            discriminator.load_state_dict(
                torch.load(os.path.join(weights_dir, 'discriminator.pth'), map_location=device))
        except FileNotFoundError:
            # Resuming without discriminator weights is allowed; the
            # discriminator then keeps its fresh initialization.
            print('Discriminator weights not found!')

    # Create the optimizers once. load_state_dict copies into the existing
    # parameters, so they could equally have been built before loading.
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.g_lr)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.d_lr)
    scheduler_G = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_G, T_max=args.num_epochs, eta_min=args.g_lr * 0.05)
    scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_D, T_max=args.num_epochs, eta_min=args.d_lr * 0.05)
    run = datetime.now().strftime("%Y-%m-%d--%H-%M-%S")

    # The curriculum ramps over the first half of training. Guard against a
    # zero length (num_epochs < 2), which would divide by zero.
    cur_length = max(1, int(0.5 * args.num_epochs))
    init_scale = 2 ** 2
    step_size = (2 ** args.up_scale - init_scale) / cur_length
    # Epochs are 1-based, so include num_epochs itself. This makes the loop
    # run num_epochs epochs, matching the schedulers' T_max.
    for epoch in range(args.start_epoch, args.num_epochs + 1):
        factor = _curriculum_factor(epoch, init_scale, step_size, args.up_scale)
        print('curriculum updated: {} '.format(factor))
        train_dataset = new_compress_curriculum(args, factor, 'train', stc=True)
        train(args, epoch, run, train_dataset, generator, discriminator,
              optimizer_G, optimizer_D, criterionL, criterionMSE, tensor, device, patch)
        scheduler_G.step()
        scheduler_D.step()
        # Validate after every epoch.
        fid, psnr = test(args, generator, data.compress_csv_path('valid'))
        print_output(generator, valid_dataset, device)
        print('\r>>>> PSNR: {}, FID: {}'.format(psnr, fid))
    test(args, generator, data.compress_csv_path('valid'), stitching=True)
def new_compress_curriculum(args, cur_factor, csv='train', stc=False):
    """Return a shuffled DataLoader over the compressed-TMA ``csv`` split.

    Each sample goes through the following steps, in order:

    1. A random crop to ``args.patch_size`` square.
    2. Random horizontal and vertical flips.
    3. ``data.Rescale`` with ``up_factor=cur_factor`` and ``stc``.
    4. ``data.ToTensor``.

    Batch size and worker count come from ``args``.
    """
    csv_path = data.compress_csv_path(csv)
    side = (args.patch_size, args.patch_size)
    pipeline = data.Compose([
        transforms.RandomCrop(side),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        data.Rescale(side, up_factor=cur_factor, stc=stc),
        data.ToTensor(),
    ])
    dataset = data.Compress_Dataset(csv_file=csv_path, transform=pipeline)
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
def train(self):
    """Run the curriculum training loop for epochs ``start_epoch``..``num_epochs``.

    Each epoch does the following:

    - Rebuilds the training loader at the current curriculum upscale factor.
    - Runs ``self.epoch_train`` on it.
    - Steps both LR schedulers.
    - Prints validation PSNR/FID from ``self.tester.test``.
    """
    # The curriculum ramps over the first half of training. Guard against a
    # zero length (num_epochs < 2), which would divide by zero.
    cur_length = max(1, int(0.5 * self.num_epochs))
    init_scale = 2 ** 2
    step_size = (2 ** self.up_scale - init_scale) / cur_length
    # Epochs are 1-based (epoch 1 uses init_scale), so include num_epochs.
    for epoch in range(self.start_epoch, self.num_epochs + 1):
        factor = min(log2(init_scale + (epoch - 1) * step_size), self.up_scale)
        print('curriculum updated: {} '.format(factor))
        train_dataset = get_dataloader(self.num_workers, self.batch_size, self.patch_size,
                                       factor, 'train', stc=True)
        self.epoch_train(train_dataset, epoch)
        self.scheduler_G.step()
        self.scheduler_D.step()
        # Validate after every epoch.
        fid, psnr = self.tester.test(self.generator, data.compress_csv_path('valid'))
        print('\r>>>> PSNR: {}, FID: {}'.format(psnr, fid))