def finetune(self, ref_labels, ref_images):
    """Few-shot finetuning pass.

    Adapts a subset of the generator ('fc', 'conv_img', 'up' sub-modules)
    and the full discriminator(s) to the given reference frames for a
    fixed number of iterations, printing non-zero losses every 10 steps.
    """
    # Only these generator sub-modules are trainable during finetuning.
    trainable = ['fc', 'conv_img', 'up']
    gen_params, _ = self.get_train_params(self.netG, trainable)
    self.optimizer_G = self.get_optimizer(gen_params, for_discriminator=False)

    tune_discriminator = True
    if tune_discriminator:
        disc_params = list(self.netD.parameters())
        if self.add_face_D:
            # Include the face discriminator's parameters as well.
            disc_params += list(self.netDf.parameters())
        self.optimizer_D = self.get_optimizer(disc_params, for_discriminator=True)

    total_iters = 100
    for step in range(1, total_iters + 1):
        # Use a randomly selected (and randomly rolled) reference frame
        # as the pseudo target for this step.
        frame = np.random.randint(ref_labels.size(1))
        tgt_label, tgt_image = random_roll([ref_labels[:, frame], ref_images[:, frame]])
        tgt_label = tgt_label.unsqueeze(1)
        tgt_image = tgt_image.unsqueeze(1)

        # Generator update.
        g_losses, generated, prev = self.forward_generator(tgt_label, tgt_image, ref_labels, ref_images)
        g_losses = loss_backward(self.opt, g_losses, self.optimizer_G)

        # Discriminator update.
        d_losses = []
        if tune_discriminator:
            d_losses = self.forward_discriminator(tgt_label, tgt_image, ref_labels, ref_images)
            d_losses = loss_backward(self.opt, d_losses, self.optimizer_D)

        # Periodic loss report (non-zero entries only).
        if step % 10 == 0:
            loss_dict = dict(zip(self.lossCollector.loss_names, g_losses + d_losses))
            parts = ['(iters: %d) ' % step]
            for name, value in loss_dict.items():
                if value != 0:
                    parts.append('%s: %.3f ' % (name, value))
            print(''.join(parts))
def train():
    """Top-level training entry point.

    Parses options, builds the data loader, trainer, and models, then runs
    the epoch / iteration loops, alternating generator and discriminator
    updates over temporal chunks of each sequence.
    """
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        # Per-process batch size when running one process per GPU.
        # NOTE(review): statement order in the source places this under the
        # distributed branch — confirm against the original indentation.
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    # Pose datasets feed labels (rather than images) to the flow network below.
    pose = 'pose' in opt.dataset_mode

    ### setup trainer
    trainer = Trainer(opt, data_loader)

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    # Placeholders used when flow ground truth is disabled (opt.no_flow_gt).
    flow_gt = conf_gt = [None] * 2

    for epoch in range(trainer.start_epoch, opt.niter + opt.niter_decay + 1):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(dataset, start=trainer.epoch_iter):
            trainer.start_of_iter()

            # Estimate "ground truth" optical flow / confidence with flowNet.
            if not opt.no_flow_gt:
                data_list = [data['tgt_label'], data['ref_label']] if pose else [data['tgt_image'], data['ref_image']]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [data['tgt_label'], data['tgt_image'], flow_gt, conf_gt]
            data_ref_list = [data['ref_label'], data['ref_image']]
            # Recurrent state carried across temporal chunks of the sequence.
            data_prev = [None, None]

            ############## Forward Pass ######################
            # Process the sequence in chunks of n_frames_load frames.
            for t in range(0, n_frames_total, n_frames_load):
                data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + data_prev

                # Generator step (also produces images and updated prev state).
                g_losses, generated, data_prev = model(data_list_t, save_images=trainer.save, mode='generator')
                g_losses = loss_backward(opt, g_losses, model.module.optimizer_G)

                # Discriminator step on the same chunk.
                d_losses = model(data_list_t, mode='discriminator')
                d_losses = loss_backward(opt, d_losses, model.module.optimizer_D)

            # Only the final chunk's losses end up in the log dictionary.
            loss_dict = dict(zip(model.module.lossCollector.loss_names, g_losses + d_losses))

            if trainer.end_of_iter(loss_dict, generated + data_list + data_ref_list, model):
                break
        trainer.end_of_epoch(model)
def finetune(self, ref_labels, ref_images, warp_ref_lmark, warp_ref_img, ani_img, ani_lmark):
    """Quick few-shot adaptation on the given reference frames.

    Tunes the generator's 'fc', 'conv_img' and 'up' sub-modules plus the
    discriminator for 70 iterations, sampling a random reference frame as
    the target each step.  ``ani_img`` / ``ani_lmark`` are accepted for
    interface compatibility but are not used by this variant.
    """
    gen_params, _ = self.get_train_params(self.netG, ['fc', 'conv_img', 'up'])
    self.optimizer_G = self.get_optimizer(gen_params, for_discriminator=False)

    tune_discriminator = True
    if tune_discriminator:
        self.optimizer_D = self.get_optimizer(list(self.netD.parameters()), for_discriminator=True)

    # The warped reference pair is constant across iterations, so add the
    # time dimension once up front (unsqueeze is a pure view operation).
    warp_lmark = warp_ref_lmark.unsqueeze(1)
    warp_img = warp_ref_img.unsqueeze(1)

    total_iters = 70
    for step in range(1, total_iters + 1):
        # Random reference frame serves as this step's target.
        frame = np.random.randint(ref_labels.size(1))
        tgt_label = ref_labels[:, frame].unsqueeze(1)
        tgt_image = ref_images[:, frame].unsqueeze(1)

        # Generator update (no flow/conf ground truth during finetuning).
        g_losses, generated, prev, _ = self.forward_generator(
            tgt_label=tgt_label, tgt_image=tgt_image,
            tgt_template=1, tgt_crop_image=None,
            flow_gt=[None] * 3, conf_gt=[None] * 3,
            ref_labels=ref_labels, ref_images=ref_images,
            warp_ref_lmark=warp_lmark, warp_ref_img=warp_img)
        g_losses = loss_backward(self.opt, g_losses, self.optimizer_G)

        # Discriminator update.
        d_losses = []
        if tune_discriminator:
            d_losses, _ = self.forward_discriminator(
                tgt_label, tgt_image, ref_labels, ref_images,
                warp_ref_lmark=warp_lmark, warp_ref_img=warp_img)
            d_losses = loss_backward(self.opt, d_losses, self.optimizer_D)

        # Periodic report of non-zero losses.
        if step % 10 == 0:
            report = '(iters: %d) ' % step
            loss_dict = dict(zip(self.lossCollector.loss_names, g_losses + d_losses))
            for name, value in loss_dict.items():
                if value != 0:
                    report += '%s: %.3f ' % (name, value)
            print(report)
def forward(self, opt, data, epoch, n_frames_total, n_frames_load, save_images):
    """Run one full training step over a sequence.

    Computes flow/confidence "ground truth" with the flow network (unless
    disabled), then walks the sequence in chunks of ``n_frames_load``
    frames, performing a generator and a discriminator update per chunk.

    Returns the last chunk's generated outputs together with the flow and
    confidence ground truth (both None when ``opt.no_flow_gt``).
    """
    # ground truth flownet
    if not opt.no_flow_gt:
        data_list = [data['tgt_image'], data['cropped_images'], data['warping_ref'], data['ani_image']]
        flow_gt, conf_gt = self.flowNet(data_list, epoch)
    else:
        flow_gt, conf_gt = None, None
    data_list = [data['tgt_label'], data['tgt_image'], data['cropped_images'], flow_gt, conf_gt]
    data_ref_list = [data['ref_label'], data['ref_image']]
    # Recurrent state threaded between chunks (updated by the generator call).
    data_prev = [None, None, None]
    data_ani = [data['warping_ref_lmark'], data['warping_ref'], data['ani_lmark'], data['ani_image']]

    ############## Forward Pass ######################
    for t in range(0, n_frames_total, n_frames_load):
        # Slice the per-frame tensors for this chunk and append the
        # (time-independent) reference and previous-state tensors.
        data_list_t = self.get_data_t(data_list, n_frames_load, t) + data_ref_list + \
            self.get_data_t(data_ani, n_frames_load, t) + data_prev

        # Generator update.
        g_losses, generated, data_prev, ref_idx = self.v2vModel(data_list_t, save_images=save_images, mode='generator', ref_idx_fix=None)
        g_losses = loss_backward(opt, g_losses, self.v2vModel.module.optimizer_G)

        # Discriminator update on the same chunk.
        d_losses, _ = self.v2vModel(data_list_t, mode='discriminator', ref_idx_fix=None)
        d_losses = loss_backward(opt, d_losses, self.v2vModel.module.optimizer_D)

    # Note: 'generated' holds only the final chunk's outputs.
    return generated, flow_gt, conf_gt
def finetune_call_multi(self, tgt_label_list, tgt_image_list, ref_label_list, ref_image_list, warp_ref_lmark_list, warp_ref_img_list, ani_lmark_list=None, ani_img_list=None, iterations=0):
    """Few-shot finetuning over multiple clips.

    For ``iterations`` passes, loops over each clip in the input lists,
    encodes it, picks a random target frame, optionally sub-samples the
    references down to ``opt.n_shot``, and performs one generator and one
    discriminator update.  Clears ``opt.finetune`` when done.
    """
    # Only these generator sub-modules are trained during finetuning.
    train_names = ['fc', 'conv_img', 'up']
    params, _ = self.get_train_params(self.netG, train_names)
    self.optimizer_G = self.get_optimizer(params, for_discriminator=False)

    update_D = True
    if update_D:
        params = list(self.netD.parameters())
        self.optimizer_D = self.get_optimizer(params, for_discriminator=True)

    for iteration in tqdm(range(iterations)):
        for data_id in tqdm(range(len(tgt_label_list))):
            # Encode one clip's raw tensors into model-ready inputs.
            tgt_labels, tgt_images, ref_labels, ref_images, warp_ref_lmark, warp_ref_img, ani_lmark, ani_img = \
                encode_input_finetune(self.opt, data_list=[tgt_label_list[data_id], tgt_image_list[data_id], \
                    ref_label_list[data_id], ref_image_list[data_id], \
                    warp_ref_lmark_list[data_id], warp_ref_img_list[data_id], \
                    ani_lmark_list[data_id] if ani_lmark_list is not None else None, \
                    ani_img_list[data_id] if ani_img_list is not None else None], dummy_bs=0)

            # Random target frame from this clip.
            idx = np.random.randint(tgt_labels.size(1))
            tgt_label, tgt_image = tgt_labels[:, idx].unsqueeze(
                1), tgt_images[:, idx].unsqueeze(1)

            # Defaults: use everything as-is; refined below when needed.
            ref_labels_finetune, ref_images_finetune = ref_labels, ref_images
            warp_ref_lmark_finetune, warp_ref_img_finetune = warp_ref_lmark, warp_ref_img
            ani_lmark_finetune, ani_img_finetune = ani_lmark, ani_img

            if ani_img is not None:
                # Animation frames must align one-to-one with target frames.
                assert ani_img.shape[1] == ani_lmark.shape[
                    1] == tgt_labels.shape[1]
                ani_img_finetune = ani_img[:, idx].unsqueeze(1)
                ani_lmark_finetune = ani_lmark[:, idx].unsqueeze(1)

            if self.opt.n_shot < ref_labels.shape[1]:
                # Sub-sample the references down to n_shot frames.
                # NOTE(review): np.random.choice without replace=False samples
                # WITH replacement, so duplicates are possible — confirm intended.
                idxs = np.random.choice(ref_labels.shape[1], self.opt.n_shot)
                ref_labels_finetune = ref_labels[:, idxs]
                ref_images_finetune = ref_images[:, idxs]

            if warp_ref_lmark.shape[1] > 1:
                if self.opt.n_shot >= ref_labels.shape[1]:
                    # idxs was not set by the branch above; draw it now.
                    idxs = np.random.choice(ref_labels.shape[1], self.opt.n_shot)
                # Keep a single warped reference (the first sampled index).
                warp_ref_lmark_finetune = warp_ref_lmark[:, idxs[0]].unsqueeze(
                    1)
                warp_ref_img_finetune = warp_ref_img[:, idxs[0]].unsqueeze(1)

            # Generator update (no flow/conf ground truth during finetuning).
            g_losses, generated, prev, _ = self.forward_generator(tgt_label=tgt_label, tgt_image=tgt_image, \
                tgt_template=1, tgt_crop_image=None, \
                flow_gt=[None]*3, conf_gt=[None]*3, \
                ref_labels=ref_labels_finetune, ref_images=ref_images_finetune, \
                warp_ref_lmark=warp_ref_lmark_finetune, warp_ref_img=warp_ref_img_finetune, \
                ani_lmark=ani_lmark_finetune, ani_img=ani_img_finetune)
            g_losses = loss_backward(self.opt, g_losses, self.optimizer_G)

            # Discriminator update.
            d_losses = []
            if update_D:
                d_losses, _ = self.forward_discriminator(
                    tgt_label, tgt_image, ref_labels_finetune,
                    ref_images_finetune,
                    warp_ref_lmark=warp_ref_lmark_finetune,
                    warp_ref_img=warp_ref_img_finetune)
                d_losses = loss_backward(self.opt, d_losses, self.optimizer_D)

            # Periodic report of non-zero losses.
            if (iteration % 10) == 0:
                message = '(iters: %d) ' % iteration
                loss_dict = dict(
                    zip(self.lossCollector.loss_names, g_losses + d_losses))
                for k, v in loss_dict.items():
                    if v != 0:
                        message += '%s: %.3f ' % (k, v)
                print(message)

    # Finetuning is complete; switch the model back to normal operation.
    self.opt.finetune = False
def train():
    """Top-level training entry point (animation-warping variant).

    Parses options, builds the data loader, trainer, flow network and
    main model, then runs the epoch / iteration loops.  Each iteration
    walks the sequence in temporal chunks, alternating generator and
    discriminator updates, and accumulates intermediate outputs in
    ``prevs`` for visualization at the end of the iteration.
    """
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        # Per-process batch size when running one process per GPU.
        # NOTE(review): statement order in the source places this under the
        # distributed branch — confirm against the original indentation.
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)

    ### setup dataset
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    ### setup trainer
    trainer = Trainer(opt, data_loader)

    ### setup models
    model, flowNet = create_model(opt, trainer.start_epoch)
    # Placeholders used when flow ground truth is disabled (opt.no_flow_gt).
    flow_gt = conf_gt = [None] * 3
    # Fixed reference indices passed through to the model each step.
    ref_idx_fix = torch.zeros([opt.batchSize])

    for epoch in tqdm(range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
        for idx, data in enumerate(tqdm(dataset), start=trainer.epoch_iter):
            trainer.start_of_iter()

            # Without animation warping, blank out the animation-related entries.
            if not opt.warp_ani:
                data.update({'ani_image': None, 'ani_lmark': None, 'cropped_images': None, 'cropped_lmarks': None})

            # Flow / confidence "ground truth" from the flow network.
            if not opt.no_flow_gt:
                data_list = [data['tgt_mask_images'], data['cropped_images'], data['warping_ref'], data['ani_image']]
                flow_gt, conf_gt = flowNet(data_list, epoch)
            data_list = [data['tgt_label'], data['tgt_image'], data['tgt_template'], data['cropped_images'], flow_gt, conf_gt]
            data_ref_list = [data['ref_label'], data['ref_image']]
            # Recurrent state threaded between temporal chunks.
            data_prev = [None, None, None]
            data_ani = [data['warping_ref_lmark'], data['warping_ref'], data['ori_warping_refs'], data['ani_lmark'], data['ani_image']]

            ############## Forward Pass ######################
            # Per-chunk intermediate outputs accumulated for visualization.
            prevs = {"raw_images": [], "synthesized_images": [],
                     "prev_warp_images": [], "prev_weights": [],
                     "ani_warp_images": [], "ani_weights": [],
                     "ref_warp_images": [], "ref_weights": [],
                     "ref_flows": [], "prev_flows": [], "ani_flows": [],
                     "ani_syn": []}

            for t in range(0, n_frames_total, n_frames_load):
                # Chunk slices plus the (time-independent) reference,
                # animation, and previous-state tensors.
                data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + \
                    get_data_t(data_ani, n_frames_load, t) + data_prev

                # Generator update.
                g_losses, generated, data_prev, ref_idx = model(data_list_t, save_images=trainer.save, mode='generator', ref_idx_fix=ref_idx_fix)
                g_losses = loss_backward(opt, g_losses,
                                         model.module.optimizer_G)

                # Discriminator update on the same chunk.
                d_losses, _ = model(data_list_t, mode='discriminator', ref_idx_fix=ref_idx_fix)
                d_losses = loss_backward(opt, d_losses, model.module.optimizer_D)

                # store previous
                store_prev(generated, prevs)

            # Only the final chunk's losses end up in the log dictionary.
            loss_dict = dict(zip(model.module.lossCollector.loss_names, g_losses + d_losses))

            output_data_list = [prevs] + [data['ref_image']] + data_ani + data_list + [data['tgt_mask_images']]
            if trainer.end_of_iter(loss_dict, output_data_list, model):
                break
        trainer.end_of_epoch(model)