def forward(self, lr_curr, lr_prev, hr_prev):
    """Super-resolve one frame given the previous LR/HR frames.

    Parameters:
        :param lr_curr: the current lr data in shape nchw
        :param lr_prev: the previous lr data in shape nc(4h)(4w's lr counterpart) nchw
        :param hr_prev: the previous hr data in shape nc(4h)(4w)
    """

    # optical flow from lr_curr to lr_prev
    flow_lr = self.fnet(lr_curr, lr_prev)

    # pad the flow back to the LR spatial size when h/w are not
    # multiples of 8 (presumably fnet shrinks to a multiple of 8 —
    # TODO confirm against fnet's architecture)
    extra_h = lr_curr.size(2) % 8
    extra_w = lr_curr.size(3) % 8
    flow_lr = F.pad(flow_lr, (0, extra_w, 0, extra_h), 'reflect')

    # upscale the flow to HR resolution
    flow_hr = self.scale * self.upsample_func(flow_lr)

    # align the previous HR frame to the current one, then super-resolve
    warped_hr_prev = backward_warp(hr_prev, flow_hr)
    return self.srnet(lr_curr, space_to_depth(warped_hr_prev, self.scale))
def forward_sequence(self, lr_data):
    """Super-resolve a whole LR sequence recurrently.

    Parameters:
        :param lr_data: lr data in shape ntchw
    """

    n, t, c, lr_h, lr_w = lr_data.size()
    hr_h, hr_w = self.scale * lr_h, self.scale * lr_w

    # batch all (prev, curr) frame pairs and estimate LR flow in one shot
    lr_prev = lr_data[:, :-1, ...].reshape(n * (t - 1), c, lr_h, lr_w)
    lr_curr = lr_data[:, 1:, ...].reshape(n * (t - 1), c, lr_h, lr_w)
    lr_flow = self.fnet(lr_curr, lr_prev)  # n*(t-1),2,h,w

    # upscale flows to HR resolution
    hr_flow = (self.scale * self.upsample_func(lr_flow)).view(
        n, t - 1, 2, hr_h, hr_w)

    # the first frame has no predecessor: feed an all-zero hidden state
    hr_prev = self.srnet(
        lr_data[:, 0, ...],
        torch.zeros(n, self.out_nc, lr_h, lr_w,
                    dtype=torch.float32, device=lr_data.device))
    hr_data = [hr_prev]

    # recurrently super-resolve the remaining frames, each time warping
    # the previous HR output with the corresponding flow
    for i in range(1, t):
        warped = backward_warp(hr_prev, hr_flow[:, i - 1, ...])
        hr_prev = self.srnet(
            lr_data[:, i, ...], space_to_depth(warped, self.scale))
        hr_data.append(hr_prev)

    hr_data = torch.stack(hr_data, dim=1)  # n,t,c,hr_h,hr_w

    # construct output dict
    return {
        'hr_data': hr_data,  # n,t,c,hr_h,hr_w
        'hr_flow': hr_flow,  # n,t-1,2,hr_h,hr_w
        'lr_prev': lr_prev,  # n*(t-1),c,lr_h,lr_w
        'lr_curr': lr_curr,  # n*(t-1),c,lr_h,lr_w
        'lr_flow': lr_flow,  # n*(t-1),2,lr_h,lr_w
    }
def train(self, data):
    """Run one mini-batch training step of the generator.

    Parameters:
        :param data: a batch of training data (lr & gt) in shape ntchw
    """

    # prepare data
    lr_data, gt_data = data['lr'], data['gt']

    # switch to training mode and reset gradients
    self.net_G.train()
    self.optim_G.zero_grad()

    # forward the generator over the whole sequence
    outputs = self.net_G.forward_sequence(lr_data)
    hr_data = outputs['hr_data']

    # accumulate generator losses
    loss_G = 0
    self.log_dict = OrderedDict()

    # pixel loss on the super-resolved frames
    pix_w = self.opt['train']['pixel_crit'].get('weight', 1)
    loss_pix_G = pix_w * self.pix_crit(hr_data, gt_data)
    loss_G = loss_G + loss_pix_G
    self.log_dict['l_pix_G'] = loss_pix_G.item()

    # optional warping loss: the estimated LR flow should map lr_prev
    # onto lr_curr
    if self.warp_crit is not None:
        lr_warp = net_utils.backward_warp(
            outputs['lr_prev'], outputs['lr_flow'])
        warp_w = self.opt['train']['warping_crit'].get('weight', 1)
        loss_warp_G = warp_w * self.warp_crit(lr_warp, outputs['lr_curr'])
        loss_G = loss_G + loss_warp_G
        self.log_dict['l_warp_G'] = loss_warp_G.item()

    # backprop and update
    loss_G.backward()
    self.optim_G.step()
def profile(self, lr_size, device):
    """Profile GFLOPs and parameter counts of the flow and SR modules.

    Parameters:
        :param lr_size: spatial size of the dummy LR input
        :param device: device to run the dummy forward pass on
    :return: (gflops_dict, params_dict), each keyed by module name
    """

    gflops_dict, params_dict = OrderedDict(), OrderedDict()

    # generate dummy input data
    lr_curr, lr_prev, hr_prev = self.generate_dummy_data(lr_size, device)

    # --- module 1: flow estimation module ---
    lr_flow = register(self.fnet, [lr_curr, lr_prev])
    gflops_dict['FNet'], params_dict['FNet'] = parse_model_info(self.fnet)

    # --- module 2: sr module ---
    # pad flow back to the LR size when h/w are not multiples of 8
    # (same preprocessing as in forward())
    extra_h = lr_curr.size(2) % 8
    extra_w = lr_curr.size(3) % 8
    flow_pad = F.pad(lr_flow, (0, extra_w, 0, extra_h), 'reflect')
    flow_hr = self.scale * self.upsample_func(flow_pad)
    warped = backward_warp(hr_prev, flow_hr)
    _ = register(self.srnet, [lr_curr, space_to_depth(warped, self.scale)])
    gflops_dict['SRNet'], params_dict['SRNet'] = parse_model_info(self.srnet)

    return gflops_dict, params_dict
def train(self, data):
    """Run one adversarial mini-batch training step (G then D).

    Parameters:
        :param data: a batch of training tensor with shape NTCHW
    """

    # ------------ prepare data ------------ #
    lr_data, gt_data = data['lr'], data['gt']

    n, t, lr_c, lr_h, lr_w = lr_data.size()
    _, _, hr_c, gt_h, gt_w = gt_data.size()

    # generate bicubic upsampled data (conditional input for D)
    bi_data = self.net_G.upsample_func(
        lr_data.view(n * t, lr_c, lr_h, lr_w)).view(n, t, lr_c, gt_h, gt_w)

    # augment data for pingpong criterion
    # i.e., (0,1,2,...,t-2,t-1) -> (0,1,2,...,t-2,t-1,t-2,...,2,1,0)
    if self.pp_crit is not None:
        lr_rev = lr_data.flip(1)[:, 1:, ...]
        gt_rev = gt_data.flip(1)[:, 1:, ...]
        bi_rev = bi_data.flip(1)[:, 1:, ...]

        lr_data = torch.cat([lr_data, lr_rev], dim=1)
        gt_data = torch.cat([gt_data, gt_rev], dim=1)
        bi_data = torch.cat([bi_data, bi_rev], dim=1)

    # ------------ clear optimizers ------------ #
    self.net_G.train()
    self.net_D.train()
    self.optim_G.zero_grad()
    self.optim_D.zero_grad()

    # ------------ forward G ------------ #
    net_G_output_dict = self.net_G.forward_sequence(lr_data)
    hr_data = net_G_output_dict['hr_data']

    # ------------ forward D ------------ #
    # D is updated first, so its params need gradients enabled here
    for param in self.net_D.parameters():
        param.requires_grad = True

    # feed additional data needed by D
    net_D_input_dict = {
        'net_G': self.net_G,
        'lr_data': lr_data,
        'bi_data': bi_data,
        'use_pp_crit': (self.pp_crit is not None),
        'crop_border_ratio': self.opt['train']['discriminator'].get(
            'crop_border_ratio', 1.0)
    }
    net_D_input_dict.update(net_G_output_dict)

    # forward real sequence (gt)
    # NOTE(review): "oputput" is a pre-existing variable-name typo,
    # intentionally left unchanged in this documentation-only pass
    real_pred, net_D_oputput_dict = self.net_D.forward_sequence(
        gt_data, net_D_input_dict)

    # reuse internal data (e.g., lr optical flow) to reduce computations
    net_D_input_dict.update(net_D_oputput_dict)

    # forward fake sequence (hr); detached so D's loss does not reach G
    fake_pred, _ = self.net_D.forward_sequence(
        hr_data.detach(), net_D_input_dict)

    # ------------ optimize D ------------ #
    self.log_dict = OrderedDict()
    real_pred_D, fake_pred_D = real_pred[0], fake_pred[0]

    # select D update policy
    update_policy = self.opt['train']['discriminator']['update_policy']
    if update_policy == 'adaptive':
        # update D adaptively: skip the update when D is already far
        # ahead of G (large gap between real and fake log-probs)
        logged_real_pred_D = torch.log(torch.sigmoid(real_pred_D) + 1e-8)
        logged_fake_pred_D = torch.log(torch.sigmoid(fake_pred_D) + 1e-8)
        distance = logged_real_pred_D.mean() - logged_fake_pred_D.mean()
        threshold = self.opt['train']['discriminator']['update_threshold']
        upd_D = distance.item() < threshold
    else:
        upd_D = True

    if upd_D:
        self.cnt_upd_D += 1

        real_loss_D = self.gan_crit(real_pred_D, 1)
        fake_loss_D = self.gan_crit(fake_pred_D, 0)
        loss_D = real_loss_D + fake_loss_D

        # update D
        loss_D.backward()
        self.optim_D.step()
    else:
        # D skipped this step; log a zero loss instead
        loss_D = torch.zeros(1)

    # logging
    self.log_dict['l_gan_D'] = loss_D.item()
    self.log_dict['p_real_D'] = real_pred_D.mean().item()
    self.log_dict['p_fake_D'] = fake_pred_D.mean().item()
    if update_policy == 'adaptive':
        self.log_dict['distance'] = distance.item()
        self.log_dict['n_upd_D'] = self.cnt_upd_D

    # ------------ optimize G ------------ #
    # freeze D while updating G
    for param in self.net_D.parameters():
        param.requires_grad = False

    # calculate losses
    loss_G = 0

    # pixel (pix) loss
    if self.pix_crit is not None:
        pix_w = self.opt['train']['pixel_crit'].get('weight', 1)
        loss_pix_G = pix_w * self.pix_crit(hr_data, gt_data)
        loss_G += loss_pix_G
        self.log_dict['l_pix_G'] = loss_pix_G.item()

    # warping (warp) loss
    if self.warp_crit is not None:
        # warp lr_prev according to lr_flow; it should match lr_curr
        lr_curr = net_G_output_dict['lr_curr']
        lr_prev = net_G_output_dict['lr_prev']
        lr_flow = net_G_output_dict['lr_flow']
        lr_warp = net_utils.backward_warp(lr_prev, lr_flow)

        warp_w = self.opt['train']['warping_crit'].get('weight', 1)
        loss_warp_G = warp_w * self.warp_crit(lr_warp, lr_curr)
        loss_G += loss_warp_G
        self.log_dict['l_warp_G'] = loss_warp_G.item()

    # feature (feat) loss
    if self.feat_crit is not None:
        # fold time into the batch dim for the feature extractor
        hr_merge = hr_data.view(-1, hr_c, gt_h, gt_w)
        gt_merge = gt_data.view(-1, hr_c, gt_h, gt_w)

        hr_feat_lst = self.net_F(hr_merge)
        gt_feat_lst = self.net_F(gt_merge)
        loss_feat_G = 0
        for hr_feat, gt_feat in zip(hr_feat_lst, gt_feat_lst):
            loss_feat_G += self.feat_crit(hr_feat, gt_feat.detach())

        feat_w = self.opt['train']['feature_crit'].get('weight', 1)
        loss_feat_G = feat_w * loss_feat_G
        loss_G += loss_feat_G
        self.log_dict['l_feat_G'] = loss_feat_G.item()

    # ping-pong (pp) loss
    if self.pp_crit is not None:
        tempo_extent = self.opt['train']['tempo_extent']
        hr_data_fw = hr_data[:, :tempo_extent - 1, ...]      # -------->|
        hr_data_bw = hr_data[:, tempo_extent:, ...].flip(1)  # <--------|

        pp_w = self.opt['train']['pingpong_crit'].get('weight', 1)
        loss_pp_G = pp_w * self.pp_crit(hr_data_fw, hr_data_bw)
        loss_G += loss_pp_G
        self.log_dict['l_pp_G'] = loss_pp_G.item()

    # feature matching (fm) loss
    if self.fm_crit is not None:
        # non-detached forward so gradients reach G through D
        fake_pred, _ = self.net_D.forward_sequence(hr_data, net_D_input_dict)
        fake_feat_lst, real_feat_lst = fake_pred[-1], real_pred[-1]

        layer_norm = self.opt['train']['feature_matching_crit'].get(
            'layer_norm', [12.0, 14.0, 24.0, 100.0])
        loss_fm_G = 0
        for i in range(len(real_feat_lst)):
            fake_feat, real_feat = fake_feat_lst[i], real_feat_lst[i]
            loss_fm_G += self.fm_crit(
                fake_feat, real_feat.detach()) / layer_norm[i]

        fm_w = self.opt['train']['feature_matching_crit'].get('weight', 1)
        loss_fm_G = fm_w * loss_fm_G
        loss_G += loss_fm_G
        self.log_dict['l_fm_G'] = loss_fm_G.item()

    # gan loss
    if self.fm_crit is None:
        # hr_data was not yet forwarded through D in this G pass
        fake_pred, _ = self.net_D.forward_sequence(hr_data, net_D_input_dict)
    fake_pred_G = fake_pred[0]

    gan_w = self.opt['train']['gan_crit'].get('weight', 1)
    loss_gan_G = gan_w * self.gan_crit(fake_pred_G, True)
    loss_G += loss_gan_G
    self.log_dict['l_gan_G'] = loss_gan_G.item()
    self.log_dict['p_fake_G'] = fake_pred_G.mean().item()

    # update G
    loss_G.backward()
    self.optim_G.step()
def forward_sequence(self, data, args_dict):
    """Forward a video sequence through the spatio-temporal discriminator.

    :param data: should be either hr_data or gt_data, shape (n, t, c, h, w)
    :param args_dict: a dict including data/config needed here
    :return: (pred, ret_dict) — classifier output plus reusable
             intermediates (the merged HR flow)
    """

    # ------------ setup params ------------ #
    net_G = args_dict['net_G']
    lr_data = args_dict['lr_data']
    bi_data = args_dict['bi_data']
    hr_flow = args_dict['hr_flow']

    n, t, lr_c, lr_h, lr_w = lr_data.size()
    _, _, hr_c, hr_h, hr_w = data.size()

    s_size = self.spatial_size
    # D consumes non-overlapping 3-frame clips; discard other frames
    t = t // 3 * 3
    n_clip = n * t // 3  # total number of 3-frame clips in all batches

    # cropped spatial size used to remove borders for stability
    c_size = int(s_size * args_dict['crop_border_ratio'])
    n_pad = (s_size - c_size) // 2

    # ------------ compute forward & backward flow ------------ #
    if 'hr_flow_merge' not in args_dict:
        if args_dict['use_pp_crit']:
            # pingpong-augmented sequence: flows for the forward
            # direction can be read from the reversed half
            hr_flow_bw = hr_flow[:, 0:t:3, ...]  # e.g., frame1 -> frame0
            hr_flow_idle = torch.zeros_like(hr_flow_bw)  # frame1 -> frame1
            hr_flow_fw = hr_flow.flip(1)[:, 1:t:3, ...]  # frame1 -> frame2
        else:
            # estimate the forward flow explicitly with G's flow net
            lr_curr = lr_data[:, 1:t:3, ...]
            lr_curr = lr_curr.reshape(n_clip, lr_c, lr_h, lr_w)
            lr_next = lr_data[:, 2:t:3, ...]
            lr_next = lr_next.reshape(n_clip, lr_c, lr_h, lr_w)

            # compute forward flow
            lr_flow_fw = net_G.fnet(lr_curr, lr_next)
            hr_flow_fw = self.scale * net_G.upsample_func(lr_flow_fw)

            hr_flow_bw = hr_flow[:, 0:t:3, ...]  # e.g., frame1 -> frame0
            hr_flow_idle = torch.zeros_like(hr_flow_bw)  # frame1 -> frame1
            hr_flow_fw = hr_flow_fw.view(
                n, t // 3, 2, hr_h, hr_w)  # frame1 -> frame2

        # merge bw/idle/fw flows
        hr_flow_merge = torch.stack(
            [hr_flow_bw, hr_flow_idle, hr_flow_fw], dim=2)  # n,t//3,3,2,h,w

        # reshape and stop gradient propagation (D should not train G's
        # flow net through the warping below)
        hr_flow_merge = hr_flow_merge.view(n_clip * 3, 2, hr_h, hr_w).detach()
    else:
        # reused data to reduce computations
        hr_flow_merge = args_dict['hr_flow_merge']

    # ------------ build up inputs for D (3 parts) ------------ #
    # part 1: bicubic upsampled data (conditional inputs)
    cond_data = bi_data[:, :t, ...].reshape(n_clip, 3, lr_c, hr_h, hr_w)
    # note: permutation is not necessarily needed here, it's just to keep
    # the same impl. as TecoGAN-Tensorflow (i.e., rrrgggbbb)
    cond_data = cond_data.permute(0, 2, 1, 3, 4)
    cond_data = cond_data.reshape(n_clip, lr_c * 3, hr_h, hr_w)

    # part 2: original data
    orig_data = data[:, :t, ...].reshape(n_clip, 3, hr_c, hr_h, hr_w)
    orig_data = orig_data.permute(0, 2, 1, 3, 4)
    orig_data = orig_data.reshape(n_clip, hr_c * 3, hr_h, hr_w)

    # part 3: each frame warped by the corresponding merged flow
    warp_data = backward_warp(
        data[:, :t, ...].reshape(n * t, hr_c, hr_h, hr_w), hr_flow_merge)
    warp_data = warp_data.view(n_clip, 3, hr_c, hr_h, hr_w)
    warp_data = warp_data.permute(0, 2, 1, 3, 4)
    warp_data = warp_data.reshape(n_clip, hr_c * 3, hr_h, hr_w)
    # remove border to increase training stability as proposed in TecoGAN
    warp_data = F.pad(
        warp_data[..., n_pad:n_pad + c_size, n_pad:n_pad + c_size],
        (n_pad, ) * 4, mode='constant')

    # combine 3 parts together
    input_data = torch.cat([orig_data, warp_data, cond_data], dim=1)

    # ------------ classify ------------ #
    pred = self.forward(input_data)  # out, feature_list

    # construct output dict (return other data beside pred)
    ret_dict = {'hr_flow_merge': hr_flow_merge}

    return pred, ret_dict
def train(self):
    """Run one (optionally distributed) adversarial training step.

    Reads the pre-set batches `self.lr_data` / `self.gt_data` (NTCHW),
    updates net_D then net_G, and records losses in `self.log_dict`.
    """
    # === prepare data === #
    lr_data, gt_data = self.lr_data, self.gt_data

    n, t, c, lr_h, lr_w = lr_data.size()
    _, _, _, gt_h, gt_w = gt_data.size()

    # generate bicubic upsampled data (conditional input for D)
    upsample_fn = self.get_bare_model(self.net_G).upsample_func
    bi_data = upsample_fn(
        lr_data.view(n * t, c, lr_h, lr_w)).view(n, t, c, gt_h, gt_w)

    # augment data for pingpong criterion
    # i.e., (0,1,2,...,t-2,t-1) -> (0,1,2,...,t-2,t-1,t-2,...,2,1,0)
    if self.pp_crit is not None:
        lr_rev = lr_data.flip(1)[:, 1:, ...]
        gt_rev = gt_data.flip(1)[:, 1:, ...]
        bi_rev = bi_data.flip(1)[:, 1:, ...]

        lr_data = torch.cat([lr_data, lr_rev], dim=1)
        gt_data = torch.cat([gt_data, gt_rev], dim=1)
        bi_data = torch.cat([bi_data, bi_rev], dim=1)

    # === initialize === #
    self.net_G.train()
    self.net_D.train()
    self.optim_G.zero_grad()
    self.optim_D.zero_grad()
    self.log_dict = OrderedDict()

    # === forward net_G === #
    net_G_output_dict = self.net_G(lr_data)
    hr_data = net_G_output_dict['hr_data']

    # === forward net_D === #
    # D is updated first, so its params need gradients enabled here
    for param in self.net_D.parameters():
        param.requires_grad = True

    # feed additional data
    net_D_input_dict = {
        'net_G': self.get_bare_model(self.net_G),  # TODO: check
        'lr_data': lr_data,
        'bi_data': bi_data,
        'use_pp_crit': (self.pp_crit is not None),
        'crop_border_ratio':
            self.opt['train']['discriminator'].get('crop_border_ratio', 1.0)
    }
    net_D_input_dict.update(net_G_output_dict)

    # forward real sequence (gt)
    # NOTE(review): "oputput" is a pre-existing variable-name typo,
    # intentionally left unchanged in this documentation-only pass
    real_pred, net_D_oputput_dict = self.net_D(gt_data, net_D_input_dict)
    # reuse internal data (e.g., optical flow)
    net_D_input_dict.update(net_D_oputput_dict)
    # forward fake sequence (hr); detached so D's loss does not reach G
    fake_pred, _ = self.net_D(hr_data.detach(), net_D_input_dict)

    # === optimize net_D === #
    real_pred_D, fake_pred_D = real_pred[0], fake_pred[0]

    # select D update policy
    update_policy = self.opt['train']['discriminator']['update_policy']
    if update_policy == 'adaptive':
        # update D adaptively: skip the update when D is already far
        # ahead of G (large gap between real and fake log-probs)
        logged_real_pred_D = torch.log(
            torch.sigmoid(real_pred_D) + 1e-8).mean()
        logged_fake_pred_D = torch.log(
            torch.sigmoid(fake_pred_D) + 1e-8).mean()

        if self.dist:
            # synchronize the statistic across all ranks
            dist.all_reduce(logged_real_pred_D)
            dist.all_reduce(logged_fake_pred_D)
            dist.barrier()
            logged_real_pred_D /= self.opt['world_size']
            logged_fake_pred_D /= self.opt['world_size']

        distance = (logged_real_pred_D - logged_fake_pred_D).item()
        upd_D = distance < self.opt['train']['discriminator'][
            'update_threshold']
    else:
        upd_D = True

    if upd_D:
        self.cnt_upd_D += 1.0

        real_loss_D = self.gan_crit(real_pred_D, True)
        fake_loss_D = self.gan_crit(fake_pred_D, False)
        loss_D = real_loss_D + fake_loss_D

        # update net_D
        loss_D.backward()
        self.optim_D.step()
    else:
        # D skipped this step; log a zero loss instead
        loss_D = torch.zeros(1)

    # logging
    self.log_dict['l_gan_D'] = loss_D.item()
    self.log_dict['p_real_D'] = real_pred_D.mean().item()
    self.log_dict['p_fake_D'] = fake_pred_D.mean().item()
    if update_policy == 'adaptive':
        self.log_dict['distance'] = distance
        self.log_dict['n_upd_D'] = self.cnt_upd_D

    # === optimize net_G === #
    # freeze D while updating G
    for param in self.net_D.parameters():
        param.requires_grad = False

    # calculate losses
    loss_G = 0

    # pixel (pix) loss
    if self.pix_crit is not None:
        pix_w = self.opt['train']['pixel_crit'].get('weight', 1)
        loss_pix_G = pix_w * self.pix_crit(hr_data, gt_data)
        loss_G += loss_pix_G
        self.log_dict['l_pix_G'] = loss_pix_G.item()

    # warping (warp) loss
    if self.warp_crit is not None:
        # warp lr_prev according to lr_flow; it should match lr_curr
        lr_curr = net_G_output_dict['lr_curr']
        lr_prev = net_G_output_dict['lr_prev']
        lr_flow = net_G_output_dict['lr_flow']
        lr_warp = net_utils.backward_warp(lr_prev, lr_flow)

        warp_w = self.opt['train']['warping_crit'].get('weight', 1)
        loss_warp_G = warp_w * self.warp_crit(lr_warp, lr_curr)
        loss_G += loss_warp_G
        self.log_dict['l_warp_G'] = loss_warp_G.item()

    # feature/perceptual (feat) loss
    if self.feat_crit is not None:
        # fold time into the batch dim for the feature extractor
        hr_merge = hr_data.view(-1, c, gt_h, gt_w)
        gt_merge = gt_data.view(-1, c, gt_h, gt_w)

        hr_feat_lst = self.net_F(hr_merge)
        gt_feat_lst = self.net_F(gt_merge)
        loss_feat_G = 0
        for hr_feat, gt_feat in zip(hr_feat_lst, gt_feat_lst):
            loss_feat_G += self.feat_crit(hr_feat, gt_feat.detach())

        feat_w = self.opt['train']['feature_crit'].get('weight', 1)
        loss_feat_G = feat_w * loss_feat_G
        loss_G += loss_feat_G
        self.log_dict['l_feat_G'] = loss_feat_G.item()

    # ping-pong (pp) loss
    if self.pp_crit is not None:
        tempo_extent = self.opt['train']['tempo_extent']
        hr_data_fw = hr_data[:, :tempo_extent - 1, ...]      # -------->|
        hr_data_bw = hr_data[:, tempo_extent:, ...].flip(1)  # <--------|

        pp_w = self.opt['train']['pingpong_crit'].get('weight', 1)
        loss_pp_G = pp_w * self.pp_crit(hr_data_fw, hr_data_bw)
        loss_G += loss_pp_G
        self.log_dict['l_pp_G'] = loss_pp_G.item()

    # feature matching (fm) loss
    if self.fm_crit is not None:
        # non-detached forward so gradients reach net_G through net_D
        fake_pred, _ = self.net_D(hr_data, net_D_input_dict)
        fake_feat_lst, real_feat_lst = fake_pred[-1], real_pred[-1]

        layer_norm = self.opt['train']['feature_matching_crit'].get(
            'layer_norm', [12.0, 14.0, 24.0, 100.0])
        loss_fm_G = 0
        for i in range(len(real_feat_lst)):
            fake_feat, real_feat = fake_feat_lst[i], real_feat_lst[i]
            loss_fm_G += self.fm_crit(
                fake_feat, real_feat.detach()) / layer_norm[i]

        fm_w = self.opt['train']['feature_matching_crit'].get('weight', 1)
        loss_fm_G = fm_w * loss_fm_G
        loss_G += loss_fm_G
        self.log_dict['l_fm_G'] = loss_fm_G.item()

    # gan loss
    if self.fm_crit is None:
        # hr_data was not yet forwarded through D in this G pass
        fake_pred, _ = self.net_D(hr_data, net_D_input_dict)
    fake_pred_G = fake_pred[0]

    gan_w = self.opt['train']['gan_crit'].get('weight', 1)
    loss_gan_G = gan_w * self.gan_crit(fake_pred_G, True)
    loss_G += loss_gan_G
    self.log_dict['l_gan_G'] = loss_gan_G.item()
    self.log_dict['p_fake_G'] = fake_pred_G.mean().item()

    # update net_G
    loss_G.backward()
    self.optim_G.step()