def __init__(self, transfer, descriptor, optimizer, wh, ctx, logger):
    """Store the networks, optimizer, loss modules, and training context."""
    self._init_metric()
    self.transfer = transfer
    self.descriptor = descriptor
    self.optimizer = optimizer
    self.gram_targets = []
    self.wh = wh

    # losses
    self.content_loss = CustomMSE()
    self.style_loss = CustomSSE()
    self.tv_loss = TVLoss()
    self.frame_coherent_loss = LuminanceLoss()
    self.fm_coherent_loss = FeatureLoss()

    self.ctx = ctx
    self.logger = logger
    self.train_tic = timer()
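# The loss modules wired up above (CustomMSE, CustomSSE, TVLoss, LuminanceLoss,
# FeatureLoss) are project-specific and their definitions are not shown here.
# As a hedged illustration only, a total-variation regularizer along the lines of
# the TVLoss used throughout this code could look like the sketch below; the class
# name, weighting, and reduction are assumptions, not the project's actual code.
import torch
import torch.nn as nn

class TVLossSketch(nn.Module):
    """Anisotropic total-variation penalty on a batch of images (illustrative sketch)."""

    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        # Sum of absolute differences between neighbouring pixels along height and
        # width, averaged over the batch dimension.
        dh = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
        dw = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()
        return self.weight * (dh + dw) / x.size(0)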
def train(args):
    """Two-stage SRGAN training: L2 pre-training of the generator, then
    adversarial fine-tuning with perceptual, adversarial, and TV losses."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([crop(args.scale, args.patch_size), augmentation()])
    dataset = mydata(GT_path=args.GT_path, LR_path=args.LR_path,
                     in_memory=args.in_memory, transform=transform)
    loader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=True, num_workers=args.num_workers)

    generator = Generator(img_feat=3, n_feats=64, kernel_size=3,
                          num_block=args.res_num, scale=args.scale)

    if args.fine_tuning:
        generator.load_state_dict(torch.load(args.generator_path))
        print("pre-trained model is loaded")
        print("path : %s" % (args.generator_path))

    generator = generator.to(device)
    generator.train()

    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr=1e-4)

    pre_epoch = 0
    fine_epoch = 0

    #### Train using L2_loss
    while pre_epoch < args.pre_train_epoch:
        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            output, _ = generator(lr)
            loss = l2_loss(gt, output)

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

        pre_epoch += 1

        if pre_epoch % 2 == 0:
            print(pre_epoch)
            print(loss.item())
            print('=========')

        if pre_epoch % 800 == 0:
            torch.save(generator.state_dict(),
                       './model/pre_trained_model_%03d.pt' % pre_epoch)

    #### Train using perceptual & adversarial loss
    vgg_net = vgg19().to(device)
    vgg_net = vgg_net.eval()

    discriminator = Discriminator(patch_size=args.patch_size * args.scale)
    discriminator = discriminator.to(device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)

    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()

    # Labels for the adversarial loss; this assumes every batch holds exactly
    # args.batch_size samples.
    real_label = torch.ones((args.batch_size, 1)).to(device)
    fake_label = torch.zeros((args.batch_size, 1)).to(device)

    while fine_epoch < args.fine_train_epoch:
        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            ## Training Discriminator
            output, _ = generator(lr)
            # Detach the generator output so the discriminator step does not
            # backpropagate into the generator.
            fake_prob = discriminator(output.detach())
            real_prob = discriminator(gt)

            d_loss_real = cross_ent(real_prob, real_label)
            d_loss_fake = cross_ent(fake_prob, fake_label)
            d_loss = d_loss_real + d_loss_fake

            g_optim.zero_grad()
            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            ## Training Generator
            output, _ = generator(lr)
            fake_prob = discriminator(output)

            _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0,
                                                      (output + 1.0) / 2.0,
                                                      layer=args.feat_layer)

            L2_loss = l2_loss(output, gt)
            percep_loss = args.vgg_rescale_coeff * _percep_loss
            adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label)
            total_variance_loss = args.tv_loss_coeff * tv_loss(
                args.vgg_rescale_coeff * (hr_feat - sr_feat) ** 2)

            g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss

            g_optim.zero_grad()
            d_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

        # Step the LR scheduler once per epoch, after the optimizer updates.
        scheduler.step()

        fine_epoch += 1

        if fine_epoch % 2 == 0:
            print(fine_epoch)
            print(g_loss.item())
            print(d_loss.item())
            print('=========')

        if fine_epoch % 500 == 0:
            torch.save(generator.state_dict(),
                       './model/SRGAN_gene_%03d.pt' % fine_epoch)
            torch.save(discriminator.state_dict(),
                       './model/SRGAN_discrim_%03d.pt' % fine_epoch)
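# Hedged usage sketch: train(args) only reads attributes from `args`, so it can be
# driven by a plain argparse.Namespace. The paths and hyperparameter values below
# are illustrative assumptions, not the project's defaults; the ./model directory
# must exist for the checkpoint saves to succeed.
from argparse import Namespace

if __name__ == '__main__':
    args = Namespace(
        GT_path='./data/HR',        # high-resolution ground-truth images (assumed path)
        LR_path='./data/LR',        # low-resolution inputs (assumed path)
        in_memory=True,
        batch_size=16,
        num_workers=4,
        scale=4,
        patch_size=24,
        res_num=16,
        fine_tuning=False,
        generator_path='./model/pre_trained_model_800.pt',
        pre_train_epoch=8000,
        fine_train_epoch=4000,
        feat_layer='relu5_4',       # assumed layer name for the perceptual loss
        vgg_rescale_coeff=0.006,
        adv_coeff=1e-3,
        tv_loss_coeff=0.0,
    )
    train(args)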
def get_style_model_and_losses(self, model_dict, style_img, content_img):
    """Build the loss network: a copy of the CNN with content, style, and TV
    loss modules inserted, with target feature maps captured from the style
    and content images."""
    self.cnn = copy.deepcopy(self.cnn)
    self.cnn = self.cnn.to(self.device)

    c_idx = 0
    r_idx = 0
    p_idx = 0

    # do some normalization!

    # list of losses in layers:
    content_losses = []
    style_losses = []
    tv_losses = []

    model = nn.Sequential()

    i = 0
    tv_mod = TVLoss(1e-3)
    model.add_module(str(len(model)), tv_mod)
    tv_losses.append(tv_mod)

    for layer in self.cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = model_dict['conv'][c_idx]
            c_idx += 1
        elif isinstance(layer, nn.ReLU):
            name = model_dict['relu'][r_idx]
            r_idx += 1
            layer = nn.ReLU(inplace=True)
        elif isinstance(layer, nn.MaxPool2d):
            name = model_dict['pool'][p_idx]
            # layer = nn.AvgPool2d(kernel_size=2, stride=2)
            p_idx += 1
        elif isinstance(layer, nn.BatchNorm2d):
            name = f'bn_{i}'
        else:
            layer_name = layer.__class__.__name__
            raise RuntimeError(f'Unrecognized Layer: {layer_name}')

        model.add_module(name, layer)

        if name in self.content_layers_default:
            content_loss = ContentLoss()
            model.add_module(f'content_loss_{i}', content_loss)
            content_losses.append(content_loss)

        if name in self.style_layers_default:
            # get feature maps from style
            style_loss = StyleLoss()
            model.add_module(f'style_loss_{i}', style_loss)
            style_losses.append(style_loss)

    # Now we trim off the layers after the last content and style losses,
    # in case there are extra layers that are not needed.
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break
    model = model[:(i + 1)]

    # Rip through the model, getting the REAL feature maps for style and content.
    for module in style_losses:
        module.mode = "capture"
    model(style_img)
    for module in style_losses:
        module.mode = "none"
    for module in content_losses:
        module.mode = "capture"
    model(content_img)

    for module in style_losses:
        module.mode = "loss"
    for module in content_losses:
        module.mode = "loss"

    return model, style_losses, content_losses, tv_losses
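# Hedged sketch of how the returned model and loss modules might drive the actual
# stylization loop. It assumes that, once set to "loss" mode, each ContentLoss /
# StyleLoss / TVLoss module stores its scalar result in a `.loss` attribute during
# the forward pass; the method name, weights, and LBFGS setup are illustrative
# assumptions, not the project's actual optimization code.
import torch
import torch.optim as optim

def run_style_transfer_sketch(self, model_dict, style_img, content_img, input_img,
                              num_steps=300, style_weight=1e6, content_weight=1.0):
    model, style_losses, content_losses, tv_losses = \
        self.get_style_model_and_losses(model_dict, style_img, content_img)

    # Optimize the pixels of the input image, not the network weights.
    input_img.requires_grad_(True)
    model.requires_grad_(False)
    optimizer = optim.LBFGS([input_img])

    step = [0]
    while step[0] < num_steps:
        def closure():
            with torch.no_grad():
                input_img.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_img)  # loss modules record their values as a side effect
            style_score = style_weight * sum(sl.loss for sl in style_losses)
            content_score = content_weight * sum(cl.loss for cl in content_losses)
            tv_score = sum(tl.loss for tl in tv_losses)
            loss = style_score + content_score + tv_score
            loss.backward()
            step[0] += 1
            return loss
        optimizer.step(closure)

    with torch.no_grad():
        input_img.clamp_(0, 1)
    return input_img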