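# Imports reconstructed for this file. The project-local module paths
# (modules.RFRNet, utils.io) follow the layout of the RFR-Inpainting
# repository; treat the exact paths as assumptions if your checkout differs.
import os
import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image

from modules.RFRNet import RFRNet, VGG16FeatureExtractor
from utils.io import load_ckpt, save_ckpt

# Mixed precision via NVIDIA apex is optional; GOT_AMP gates the fp16 path
# in RFRNetModel.train(). This try/except is a reconstruction of the usual
# guard pattern, not taken verbatim from the source.
try:
    from apex import amp
    GOT_AMP = True
except ImportError:
    GOT_AMP = False


# gram_matrix is called by the first variant's style_loss but is not defined
# in this file. A minimal sketch, assuming the standard Gram matrix
# normalized by c * h * w (the normalization the replaced inline code used,
# which also avoids fp16 underflow):
def gram_matrix(feat):
    b, c, h, w = feat.size()
    feat = feat.view(b, c, h * w)
    # Batched outer product of the flattened feature maps
    gram = torch.bmm(feat, feat.transpose(1, 2)) / (c * h * w)
    return gram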
# Variant with configurable loss weights, TensorBoard logging, optional
# apex fp16 training, and multi-GPU support.
class RFRNetModel():
    def __init__(self):
        self.loss_weights = {
            "tv": 0.1,
            "style": 120,
            "perceptual": 0.05,
            "valid": 1,
            "hole": 6
        }
        self.learning_rates = {"train": 2e-4, "finetune": 5e-5}
        self.save_freq = 10000
        self.G = None
        self.lossNet = None
        self.optm_G = None
        self.device = None
        self.iter = None
        self.real_A = None
        self.real_B = None
        self.fake_B = None
        self.comp_B = None
        self.l1_loss_val = 0.0
        # Set properly in train(); default off so update_G works standalone.
        self.fp16 = False

    def initialize_model(self, path=None, train=True):
        self.G = RFRNet()
        self.optm_G = optim.Adam(self.G.parameters(),
                                 lr=self.learning_rates["train"])
        if train:
            self.lossNet = VGG16FeatureExtractor()
        try:
            start_iter = load_ckpt(path, [('generator', self.G)],
                                   [('optimizer_G', self.optm_G)])
            if train:
                # Re-create the optimizer so the loaded optimizer state does
                # not override the configured learning rate.
                self.optm_G = optim.Adam(self.G.parameters(),
                                         lr=self.learning_rates["train"])
            print('Model initialized, iter:', start_iter)
            self.iter = start_iter
        except Exception:
            print('No trained model found, starting from scratch')
            self.iter = 0
        return self

    def cuda(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Model moved to cuda")
            self.G.cuda()
            if self.lossNet is not None:
                self.lossNet.cuda()
        else:
            self.device = torch.device("cpu")
        return self

    def multi_gpu(self):
        print(f'Multi-GPU training with {torch.cuda.device_count()} GPUs')
        self.G = nn.DataParallel(self.G)
        if self.lossNet is not None:
            self.lossNet = nn.DataParallel(self.lossNet)
        return self

    def train(self, train_loader, save_path, finetune=False, iters=450000,
              fp16=False, multi_gpu=True):
        writer = SummaryWriter()
        self.G.train(finetune=finetune)
        # Overwrite the optimizer with a lower learning rate for finetuning
        if finetune:
            self.optm_G = optim.Adam(
                filter(lambda p: p.requires_grad, self.G.parameters()),
                lr=self.learning_rates["finetune"])
        self.fp16 = fp16 and GOT_AMP
        if self.fp16:
            self.G, self.optm_G = amp.initialize(self.G, self.optm_G,
                                                 opt_level="O1")
            if self.lossNet is not None:
                self.lossNet = amp.initialize(self.lossNet, opt_level="O1")
        if multi_gpu:
            self.multi_gpu()
        print("Starting training from iteration: {:d}, finetuning: {}".format(
            self.iter, finetune))
        s_time = time.time()
        while self.iter < iters:
            for items in train_loader:
                gt_images, masks = self.__cuda__(*items)
                masked_images = gt_images * masks
                self.forward(masked_images, masks, gt_images)
                self.update_parameters()
                for k, v in self.metrics["lossG"].items():
                    writer.add_scalar(f"lossG/{k}", v, global_step=self.iter)
                self.iter += 1
                if self.iter % 200 == 0:
                    e_time = time.time()
                    int_time = e_time - s_time
                    # Average over the 200 iterations since the last reset
                    print("Iteration:%d, l1_loss:%.4f, time_taken:%.2f" %
                          (self.iter, self.l1_loss_val / 200, int_time))
                    writer.add_images("real_A", self.real_A,
                                      global_step=self.iter)
                    writer.add_images("mask", self.mask,
                                      global_step=self.iter)
                    writer.add_images("real_B", self.real_B,
                                      global_step=self.iter)
                    writer.add_images("fake_B", self.fake_B,
                                      global_step=self.iter)
                    writer.add_images("comp_B", self.comp_B,
                                      global_step=self.iter)
                    # Reset the running loss and the timer
                    s_time = time.time()
                    self.l1_loss_val = 0.0
                if self.iter % self.save_freq == 0:
                    if not os.path.exists('{:s}'.format(save_path)):
                        os.makedirs('{:s}'.format(save_path))
                    save_ckpt(
                        '{:s}/g_{:d}{}.pth'.format(
                            save_path, self.iter,
                            "_finetune" if finetune else ""),
                        [('generator', self.G)],
                        [('optimizer_G', self.optm_G)], self.iter)
        if not os.path.exists('{:s}'.format(save_path)):
            os.makedirs('{:s}'.format(save_path))
        save_ckpt(
            '{:s}/g_{:s}{}.pth'.format(save_path, "final",
                                       "_finetune" if finetune else ""),
            [('generator', self.G)],
            [('optimizer_G', self.optm_G)], self.iter)

    def test(self, test_loader, result_save_path):
        self.G.eval()
        for para in self.G.parameters():
            para.requires_grad = False
        count = 0
        for items in test_loader:
            gt_images, masks = self.__cuda__(*items)
            masked_images = gt_images * masks
            # Repeat a single-channel mask across the three image channels
            if masks.size(1) == 1:
                masks = torch.cat([masks] * 3, dim=1)
            fake_B, mask = self.G(masked_images, masks)
            comp_B = fake_B * (1 - masks) + gt_images * masks
            if not os.path.exists('{:s}/results'.format(result_save_path)):
                os.makedirs('{:s}/results'.format(result_save_path))
            for k in range(comp_B.size(0)):
                count += 1
                grid = make_grid(comp_B[k:k + 1])
                file_path = '{:s}/results/img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)
                grid = make_grid(masked_images[k:k + 1] + 1 - masks[k:k + 1])
                file_path = '{:s}/results/masked_img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

    def forward(self, masked_image, mask, gt_image):
        self.real_A = masked_image
        self.real_B = gt_image
        self.mask = mask
        fake_B, _ = self.G(masked_image, mask)
        self.fake_B = fake_B
        self.comp_B = self.fake_B * (1 - mask) + self.real_B * mask

    def update_parameters(self):
        self.update_G()
        self.update_D()

    def update_G(self):
        self.optm_G.zero_grad()
        loss_G = self.get_g_loss()
        if self.fp16:
            with amp.scale_loss(loss_G, self.optm_G) as scaled_loss:
                scaled_loss.backward()
        else:
            loss_G.backward()
        self.optm_G.step()

    def update_D(self):
        return

    def get_g_loss(self):
        real_B = self.real_B
        fake_B = self.fake_B
        comp_B = self.comp_B
        real_B_feats = self.lossNet(real_B)
        fake_B_feats = self.lossNet(fake_B)
        comp_B_feats = self.lossNet(comp_B)
        tv_loss = self.TV_loss(comp_B * (1 - self.mask))
        style_loss = self.style_loss(real_B_feats, fake_B_feats) \
            + self.style_loss(real_B_feats, comp_B_feats)
        perceptual_loss = self.perceptual_loss(real_B_feats, fake_B_feats) \
            + self.perceptual_loss(real_B_feats, comp_B_feats)
        valid_loss = self.l1_loss(real_B, fake_B, self.mask)
        hole_loss = self.l1_loss(real_B, fake_B, (1 - self.mask))
        loss_G = (tv_loss * self.loss_weights["tv"] +
                  style_loss * self.loss_weights["style"] +
                  perceptual_loss * self.loss_weights["perceptual"] +
                  valid_loss * self.loss_weights["valid"] +
                  hole_loss * self.loss_weights["hole"])
        self.l1_loss_val += valid_loss.detach() + hole_loss.detach()
        self.metrics = {
            "lossG": {
                "sum": loss_G.item(),
                "tv": tv_loss.item() * self.loss_weights["tv"],
                "style": style_loss.item() * self.loss_weights["style"],
                "perceptual":
                    perceptual_loss.item() * self.loss_weights["perceptual"],
                "valid": valid_loss.item() * self.loss_weights["valid"],
                "hole": hole_loss.item() * self.loss_weights["hole"],
            }
        }
        print(f"#{self.iter:08d} - lossG: {self.metrics['lossG']}")
        return loss_G

    @staticmethod
    def l1_loss(f1, f2, mask=1):
        return torch.mean(torch.abs(f1 - f2) * mask)

    @staticmethod
    def style_loss(A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature map lists should have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            # Gram matrices are normalized by c * h * w inside gram_matrix,
            # which avoids underflow when using mixed-precision training.
            gram_A = gram_matrix(A_feat)
            gram_B = gram_matrix(B_feat)
            loss_value += torch.mean(torch.abs(gram_A - gram_B))
        return loss_value

    @staticmethod
    def TV_loss(x):
        h_tv = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]))
        w_tv = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]))
        return h_tv + w_tv

    @staticmethod
    def perceptual_loss(A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature map lists should have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            loss_value += torch.mean(torch.abs(A_feat - B_feat))
        return loss_value

    def __cuda__(self, *args):
        return (item.to(self.device) for item in args)
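# A minimal sketch of how this first variant is driven, in the style of an
# entry-point script. The synthetic tensors below are hypothetical stand-ins:
# any DataLoader that yields (gt_image, mask) pairs, with images in [0, 1]
# and mask == 1 on valid pixels, fits the train() loop above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Random stand-in data: 3x256x256 images and 3-channel binary masks.
    # Replace with the project's real dataset before serious use.
    images = torch.rand(24, 3, 256, 256)
    masks = (torch.rand(24, 1, 256, 256) > 0.25).float().repeat(1, 3, 1, 1)
    train_loader = DataLoader(TensorDataset(images, masks),
                              batch_size=6, shuffle=True)

    model = RFRNetModel()
    model.initialize_model(path=None, train=True).cuda()
    model.train(train_loader, save_path="checkpoints", finetune=False,
                iters=1000, fp16=False, multi_gpu=False)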
# A second take on RFRNetModel: same training loop, but loss weights are
# inlined and intermediate results are shown through a visdom-style
# Visualizer instead of TensorBoard.
class RFRNetModel():
    def __init__(self):
        self.G = None
        self.lossNet = None
        self.iter = None
        self.optm_G = None
        self.device = None
        self.real_A = None
        self.real_B = None
        self.fake_B = None
        self.comp_B = None
        self.l1_loss_val = 0.0
        self.visual_names = ['real_A', 'real_B', 'mask', 'fake_B', 'comp_B']
        self.model_names = ['G']
        self.visualizer = Visualizer()

    def initialize_model(self, path=None, train=True):
        self.G = RFRNet()
        self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
        self.print_networks(False)
        if train:
            self.lossNet = VGG16FeatureExtractor()
        try:
            start_iter = load_ckpt(path, [('generator', self.G)],
                                   [('optimizer_G', self.optm_G)])
            if train:
                self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
            print('Model initialized, iter:', start_iter)
            self.iter = start_iter
        except Exception:
            print('No trained model found, starting from scratch')
            self.iter = 0

    def cuda(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Model moved to cuda")
            self.G.cuda()
            if self.lossNet is not None:
                self.lossNet.cuda()
        else:
            self.device = torch.device("cpu")

    def train(self, train_loader, save_path, finetune=False, iters=450000):
        self.G.train(finetune=finetune)
        # Overwrite the optimizer with a lower learning rate for finetuning
        if finetune:
            self.optm_G = optim.Adam(
                filter(lambda p: p.requires_grad, self.G.parameters()),
                lr=5e-5)
        print("Starting training from iteration: {:d}".format(self.iter))
        s_time = time.time()
        while self.iter < iters:
            for items in train_loader:
                gt_images, masks = self.__cuda__(*items)
                masked_images = gt_images * masks
                self.forward(masked_images, masks, gt_images)
                self.update_parameters()
                self.iter += 1
                if self.iter % 50 == 0:
                    e_time = time.time()
                    int_time = e_time - s_time
                    print("Iteration:%d, l1_loss:%.4f, time_taken:%.2f" %
                          (self.iter, self.l1_loss_val / 50, int_time))
                    s_time = time.time()
                    self.l1_loss_val = 0.0
                if self.iter % 40000 == 0:
                    if not os.path.exists('{:s}'.format(save_path)):
                        os.makedirs('{:s}'.format(save_path))
                    save_ckpt('{:s}/g_{:d}.pth'.format(save_path, self.iter),
                              [('generator', self.G)],
                              [('optimizer_G', self.optm_G)], self.iter)
                if self.iter % 200 == 0:
                    self.visualizer.display_current_results(
                        self.get_current_visuals(), self.iter, False)
        if not os.path.exists('{:s}'.format(save_path)):
            os.makedirs('{:s}'.format(save_path))
        save_ckpt('{:s}/g_{:s}.pth'.format(save_path, "final"),
                  [('generator', self.G)],
                  [('optimizer_G', self.optm_G)], self.iter)

    def test(self, test_loader, result_save_path):
        self.G.eval()
        for para in self.G.parameters():
            para.requires_grad = False
        count = 0
        for items in test_loader:
            gt_images, masks = self.__cuda__(*items)
            masked_images = gt_images * masks
            # Repeat the single-channel mask across the three image channels
            masks = torch.cat([masks] * 3, dim=1)
            fake_B, mask = self.G(masked_images, masks)
            comp_B = fake_B * (1 - masks) + gt_images * masks
            if not os.path.exists('{:s}/results'.format(result_save_path)):
                os.makedirs('{:s}/results'.format(result_save_path))
            for k in range(comp_B.size(0)):
                count += 1
                grid = make_grid(comp_B[k:k + 1])
                file_path = '{:s}/results/img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)
                grid = make_grid(masked_images[k:k + 1] + 1 - masks[k:k + 1])
                file_path = '{:s}/results/masked_img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

    def forward(self, masked_image, mask, gt_image):
        self.real_A = masked_image
        self.real_B = gt_image
        self.mask = mask
        fake_B, _ = self.G(masked_image, mask)
        self.fake_B = fake_B
        self.comp_B = self.fake_B * (1 - mask) + self.real_B * mask

    def update_parameters(self):
        self.update_G()
        self.update_D()

    def update_G(self):
        self.optm_G.zero_grad()
        loss_G = self.get_g_loss()
        loss_G.backward()
        self.optm_G.step()

    def update_D(self):
        return

    def get_g_loss(self):
        real_B = self.real_B
        fake_B = self.fake_B
        comp_B = self.comp_B
        real_B_feats = self.lossNet(real_B)
        fake_B_feats = self.lossNet(fake_B)
        comp_B_feats = self.lossNet(comp_B)
        tv_loss = self.TV_loss(comp_B * (1 - self.mask))
        style_loss = self.style_loss(real_B_feats, fake_B_feats) \
            + self.style_loss(real_B_feats, comp_B_feats)
        perceptual_loss = self.perceptual_loss(real_B_feats, fake_B_feats) \
            + self.perceptual_loss(real_B_feats, comp_B_feats)
        valid_loss = self.l1_loss(real_B, fake_B, self.mask)
        hole_loss = self.l1_loss(real_B, fake_B, (1 - self.mask))
        loss_G = (tv_loss * 0.1 + style_loss * 120 + perceptual_loss * 0.05
                  + valid_loss * 1 + hole_loss * 6)
        self.l1_loss_val += valid_loss.detach() + hole_loss.detach()
        return loss_G

    def l1_loss(self, f1, f2, mask=1):
        return torch.mean(torch.abs(f1 - f2) * mask)

    def style_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature map lists should have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            _, c, w, h = A_feat.size()
            A_feat = A_feat.view(A_feat.size(0), A_feat.size(1),
                                 A_feat.size(2) * A_feat.size(3))
            B_feat = B_feat.view(B_feat.size(0), B_feat.size(1),
                                 B_feat.size(2) * B_feat.size(3))
            A_style = torch.matmul(A_feat, A_feat.transpose(2, 1))
            B_style = torch.matmul(B_feat, B_feat.transpose(2, 1))
            loss_value += torch.mean(
                torch.abs(A_style - B_style) / (c * w * h))
        return loss_value

    def TV_loss(self, x):
        h_x = x.size(2)
        w_x = x.size(3)
        h_tv = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :h_x - 1, :]))
        w_tv = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x - 1]))
        return h_tv + w_tv

    def perceptual_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature map lists should have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            loss_value += torch.mean(torch.abs(A_feat - B_feat))
        return loss_value

    def __cuda__(self, *args):
        return (item.to(self.device) for item in args)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images
        with visdom and save them to an HTML file."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                attr = getattr(self, name)
                if isinstance(attr, list):
                    for i in range(len(attr)):
                        visual_ret[name + ':' + str(i)] = attr[i]
                else:
                    visual_ret[name] = attr
        return visual_ret

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and,
        if verbose, the network architecture.

        Parameters:
            verbose (bool) -- if True, also print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' %
                      (name, num_params / 1e6))
        print('-----------------------------------------------')
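# The variant above expects a Visualizer exposing
# display_current_results(visuals, step, save_result), matching the
# visdom-based Visualizer from the pix2pix/CycleGAN codebase
# (util.visualizer). If that dependency is unavailable, a minimal stand-in
# that just dumps the visuals to disk might look like this (hypothetical
# helper, not part of the original code):
class Visualizer:
    def __init__(self, out_dir="visuals"):
        self.out_dir = out_dir
        os.makedirs(out_dir, exist_ok=True)

    def display_current_results(self, visuals, step, save_result):
        # visuals is an OrderedDict of name -> image tensor (N, C, H, W)
        for name, image in visuals.items():
            save_image(make_grid(image),
                       '{:s}/{:s}_{:d}.png'.format(self.out_dir, name, step))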
# A third take on RFRNetModel: the data loader yields oversized "preloaded"
# batches of (gt_image, mask, masked_image) triples, which the training loop
# slices into sub-batches of batch_size on the GPU.
class RFRNetModel():
    def __init__(self):
        self.G = None
        self.lossNet = None
        self.iter = None
        self.optm_G = None
        self.device = None
        self.real_A = None
        self.real_B = None
        self.fake_B = None
        self.comp_B = None
        self.l1_loss_val = 0.0

    def initialize_model(self, path=None, train=True):
        self.G = RFRNet()
        self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
        if train:
            self.lossNet = VGG16FeatureExtractor()
        try:
            start_iter = load_ckpt(path, [('generator', self.G)],
                                   [('optimizer_G', self.optm_G)])
            if train:
                self.optm_G = optim.Adam(self.G.parameters(), lr=2e-4)
            print('Model initialized, iter:', start_iter)
            self.iter = start_iter
        except Exception:
            print('No trained model found, starting from scratch')
            self.iter = 0

    def cuda(self):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Model moved to cuda")
            self.G.cuda()
            if self.lossNet is not None:
                self.lossNet.cuda()
        else:
            self.device = torch.device("cpu")

    def train(self, train_loader, save_path, finetune=False, iters=450000,
              batch_size=6, batch_preload_count=1):
        self.G.train(finetune=finetune)
        # Overwrite the optimizer with a lower learning rate for finetuning
        if finetune:
            self.optm_G = optim.Adam(
                filter(lambda p: p.requires_grad, self.G.parameters()),
                lr=5e-5)
        print("Starting training from iteration: {:d}".format(self.iter))
        s_time = time.time()
        while self.iter < iters:
            for items in train_loader:
                gt_image_batch, mask_batch, masked_image_batch = \
                    self.__cuda__(*items)
                for batch_idx in range(0, batch_preload_count):
                    # Slice the preloaded batch into sub-batches of
                    # batch_size, clamping the right edge to what is left.
                    left = batch_idx * batch_size
                    right = min(left + batch_size, gt_image_batch.size(0))
                    gt_image = gt_image_batch[left:right]
                    mask = mask_batch[left:right]
                    masked_image = masked_image_batch[left:right]
                    if gt_image.size(0) == 0:
                        break
                    self.forward(masked_image, mask, gt_image)
                    self.update_parameters()
                    self.iter += 1
                    if self.iter % 50 == 0:
                        e_time = time.time()
                        int_time = e_time - s_time
                        print("Iteration:%d, l1_loss:%.4f, time_taken:%.2f" %
                              (self.iter, self.l1_loss_val / 50, int_time))
                        s_time = time.time()
                        self.l1_loss_val = 0.0
                    if self.iter % 40000 == 0:
                        if not os.path.exists('{:s}'.format(save_path)):
                            os.makedirs('{:s}'.format(save_path))
                        save_ckpt(
                            '{:s}/g_{:d}.pth'.format(save_path, self.iter),
                            [('generator', self.G)],
                            [('optimizer_G', self.optm_G)], self.iter)
                    if self.iter >= iters:
                        break
                if self.iter >= iters:
                    break
        print("Finished training iter %d. Saving model." % (self.iter))
        if not os.path.exists('{:s}'.format(save_path)):
            os.makedirs('{:s}'.format(save_path))
        # Save the final checkpoint twice: once under a stable name and once
        # tagged with the final iteration count.
        save_ckpt('{:s}/g_{:s}.pth'.format(save_path, "final"),
                  [('generator', self.G)],
                  [('optimizer_G', self.optm_G)], self.iter)
        save_ckpt('{:s}/g_{:s}_{:d}.pth'.format(save_path, "final",
                                                self.iter),
                  [('generator', self.G)],
                  [('optimizer_G', self.optm_G)], self.iter)

    def test(self, test_loader, result_save_path):
        self.G.eval()
        for para in self.G.parameters():
            para.requires_grad = False
        count = 0
        for items in test_loader:
            gt_images, masks, masked_images = self.__cuda__(*items)
            # Repeat the single-channel mask across the three image channels
            masks = torch.cat([masks] * 3, dim=1)
            fake_B, mask = self.G(masked_images, masks)
            comp_B = fake_B * (1 - masks) + gt_images * masks
            if not os.path.exists('{:s}/results'.format(result_save_path)):
                os.makedirs('{:s}/results'.format(result_save_path))
            for k in range(comp_B.size(0)):
                count += 1
                grid = make_grid(comp_B[k:k + 1])
                file_path = '{:s}/results/img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)
                grid = make_grid(masked_images[k:k + 1] + 1 - masks[k:k + 1])
                file_path = '{:s}/results/masked_img_{:d}.png'.format(
                    result_save_path, count)
                save_image(grid, file_path)

    def forward(self, masked_image, mask, gt_image):
        self.real_A = masked_image
        self.real_B = gt_image
        self.mask = mask
        fake_B, _ = self.G(masked_image, mask)
        self.fake_B = fake_B
        self.comp_B = self.fake_B * (1 - mask) + self.real_B * mask

    def update_parameters(self):
        self.update_G()
        self.update_D()

    def update_G(self):
        self.optm_G.zero_grad()
        loss_G = self.get_g_loss()
        loss_G.backward()
        self.optm_G.step()

    def update_D(self):
        return

    def get_g_loss(self):
        real_B = self.real_B
        fake_B = self.fake_B
        comp_B = self.comp_B
        real_B_feats = self.lossNet(real_B)
        fake_B_feats = self.lossNet(fake_B)
        comp_B_feats = self.lossNet(comp_B)
        tv_loss = self.TV_loss(comp_B * (1 - self.mask))
        style_loss = self.style_loss(real_B_feats, fake_B_feats) \
            + self.style_loss(real_B_feats, comp_B_feats)
        perceptual_loss = self.perceptual_loss(real_B_feats, fake_B_feats) \
            + self.perceptual_loss(real_B_feats, comp_B_feats)
        valid_loss = self.l1_loss(real_B, fake_B, self.mask)
        hole_loss = self.l1_loss(real_B, fake_B, (1 - self.mask))
        loss_G = (tv_loss * 0.1 + style_loss * 120 + perceptual_loss * 0.05
                  + valid_loss * 1 + hole_loss * 6)
        self.l1_loss_val += valid_loss.detach() + hole_loss.detach()
        return loss_G

    def l1_loss(self, f1, f2, mask=1):
        return torch.mean(torch.abs(f1 - f2) * mask)

    def style_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature map lists should have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            _, c, w, h = A_feat.size()
            A_feat = A_feat.view(A_feat.size(0), A_feat.size(1),
                                 A_feat.size(2) * A_feat.size(3))
            B_feat = B_feat.view(B_feat.size(0), B_feat.size(1),
                                 B_feat.size(2) * B_feat.size(3))
            A_style = torch.matmul(A_feat, A_feat.transpose(2, 1))
            B_style = torch.matmul(B_feat, B_feat.transpose(2, 1))
            loss_value += torch.mean(
                torch.abs(A_style - B_style) / (c * w * h))
        return loss_value

    def TV_loss(self, x):
        h_x = x.size(2)
        w_x = x.size(3)
        h_tv = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :h_x - 1, :]))
        w_tv = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :w_x - 1]))
        return h_tv + w_tv

    def perceptual_loss(self, A_feats, B_feats):
        assert len(A_feats) == len(B_feats), \
            "the two input feature map lists should have the same length"
        loss_value = 0.0
        for i in range(len(A_feats)):
            A_feat = A_feats[i]
            B_feat = B_feats[i]
            loss_value += torch.mean(torch.abs(A_feat - B_feat))
        return loss_value

    def __cuda__(self, *args):
        return (item.to(self.device) for item in args)
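# The variant above expects the loader to yield (gt_image, mask,
# masked_image) triples, with the masking precomputed on the CPU side, and
# an oversized batch that train() slices into batch_preload_count
# sub-batches of batch_size; the DataLoader's batch_size should therefore be
# batch_size * batch_preload_count. A minimal sketch of a matching dataset;
# the in-memory tensor lists are assumptions, stand-ins for whatever file
# loading the project actually uses:
from torch.utils.data import DataLoader, Dataset as TorchDataset


class PreMaskedDataset(TorchDataset):
    def __init__(self, images, masks):
        # images: list of (3, H, W) tensors in [0, 1]
        # masks: list of (1, H, W) tensors, 1 = valid pixel, 0 = hole
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        gt_image = self.images[idx]
        # Cycle through the mask list if there are fewer masks than images
        mask = self.masks[idx % len(self.masks)]
        masked_image = gt_image * mask
        return gt_image, mask, masked_image


# Hypothetical wiring: preload 4 sub-batches of 6 per loader step.
# loader = DataLoader(PreMaskedDataset(images, masks), batch_size=6 * 4,
#                     shuffle=True, num_workers=4)
# model.train(loader, "checkpoints", batch_size=6, batch_preload_count=4)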