def train():
    """Run the full training loop described by the parsed ``TrainOptions``.

    Steps, in order:

    1. Parse options; when ``opt.distributed`` is set, call ``init_dist()``
       and integer-divide ``opt.batchSize`` by the number of entries in
       ``opt.gpu_ids``.
    2. Build the data loader, the ``Trainer`` and the model / flow network
       pair returned by ``create_model``.
    3. For every epoch from ``trainer.start_epoch`` through
       ``opt.niter + opt.niter_decay`` (inclusive), iterate over the dataset.
       Each batch is processed in chunks of ``opt.n_frames_per_gpu`` frames;
       every chunk gets one generator step followed by one discriminator
       step.

    Takes no arguments and returns nothing.
    """
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        # Spread the configured batch size over the listed GPUs.
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)

    # --- dataset ---
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    # Selects which pair of tensors is handed to the flow network below.
    is_pose = 'pose' in opt.dataset_mode

    # --- trainer ---
    trainer = Trainer(opt, data_loader)

    # --- models ---
    model, flow_net = create_model(opt, trainer.start_epoch)
    # Placeholders used for as long as the flow network is not run
    # (opt.no_flow_gt set). Both names refer to the same list object.
    flow_gt = conf_gt = [None] * 2

    epochs = range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)
    for epoch in epochs:
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total = data_loader.dataset.n_frames_total
        n_frames_load = opt.n_frames_per_gpu

        for _, batch in enumerate(dataset, start=trainer.epoch_iter):
            trainer.start_of_iter()

            if not opt.no_flow_gt:
                if is_pose:
                    flow_inputs = [batch['tgt_label'], batch['ref_label']]
                else:
                    flow_inputs = [batch['tgt_image'], batch['ref_image']]
                flow_gt, conf_gt = flow_net(flow_inputs, epoch)

            data_list = [
                batch['tgt_label'],
                batch['tgt_image'],
                flow_gt,
                conf_gt,
            ]
            data_ref_list = [batch['ref_label'], batch['ref_image']]
            # Replaced after every chunk by the third value the generator
            # pass returns.
            data_prev = [None, None]

            # --- forward / backward, one chunk of frames at a time ---
            for t in range(0, n_frames_total, n_frames_load):
                chunk = get_data_t(data_list, n_frames_load, t)
                data_list_t = chunk + data_ref_list + data_prev

                generator_out = model(
                    data_list_t, save_images=trainer.save, mode='generator')
                g_losses, generated, data_prev = generator_out
                g_losses = loss_backward(
                    opt, g_losses, model.module.optimizer_G)

                d_losses = model(data_list_t, mode='discriminator')
                d_losses = loss_backward(
                    opt, d_losses, model.module.optimizer_D)

            # Only the losses and outputs of the last chunk reach the trainer.
            loss_names = model.module.lossCollector.loss_names
            loss_dict = dict(zip(loss_names, g_losses + d_losses))

            visuals = generated + data_list + data_ref_list
            # A truthy return from end_of_iter ends this epoch's batch loop.
            if trainer.end_of_iter(loss_dict, visuals, model):
                break

        trainer.end_of_epoch(model)
def train():
    """Run the training loop for the variant that also feeds warping inputs.

    Steps, in order:

    1. Parse options; when ``opt.distributed`` is set, call ``init_dist()``
       and integer-divide ``opt.batchSize`` by the number of entries in
       ``opt.gpu_ids``.
    2. Build the data loader, the ``Trainer`` and the model / flow network
       pair returned by ``create_model``.
    3. For every epoch from ``trainer.start_epoch`` through
       ``opt.niter + opt.niter_decay`` (inclusive), iterate over the dataset
       with ``tqdm`` progress bars. Each batch is processed in chunks of
       ``opt.n_frames_per_gpu`` frames; every chunk gets one generator step
       and one discriminator step, and the generator output is passed to
       ``store_prev`` together with the per-batch ``prevs`` dict.

    Takes no arguments and returns nothing.
    """
    opt = TrainOptions().parse()
    if opt.distributed:
        init_dist()
        # Spread the configured batch size over the listed GPUs.
        opt.batchSize = opt.batchSize // len(opt.gpu_ids)

    # --- dataset ---
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    # --- trainer ---
    trainer = Trainer(opt, data_loader)

    # --- models ---
    model, flow_net = create_model(opt, trainer.start_epoch)
    # Placeholders used for as long as the flow network is not run
    # (opt.no_flow_gt set). Both names refer to the same list object.
    flow_gt = conf_gt = [None] * 3
    # All-zero tensor with one entry per batch element, forwarded unchanged
    # to every model call as `ref_idx_fix`.
    ref_idx_fix = torch.zeros([opt.batchSize])

    # Keys whose values are forced to None when opt.warp_ani is off.
    ani_keys = ('ani_image', 'ani_lmark', 'cropped_images', 'cropped_lmarks')
    # Keys of the per-batch accumulator handed to store_prev.
    prev_keys = (
        "raw_images", "synthesized_images",
        "prev_warp_images", "prev_weights",
        "ani_warp_images", "ani_weights",
        "ref_warp_images", "ref_weights",
        "ref_flows", "prev_flows", "ani_flows",
        "ani_syn",
    )

    epochs = range(trainer.start_epoch, opt.niter + opt.niter_decay + 1)
    for epoch in tqdm(epochs):
        trainer.start_of_epoch(epoch, model, data_loader)
        n_frames_total = data_loader.dataset.n_frames_total
        n_frames_load = opt.n_frames_per_gpu

        for _, data in enumerate(tqdm(dataset), start=trainer.epoch_iter):
            trainer.start_of_iter()

            if not opt.warp_ani:
                data.update(dict.fromkeys(ani_keys, None))

            if not opt.no_flow_gt:
                flow_inputs = [
                    data['tgt_mask_images'],
                    data['cropped_images'],
                    data['warping_ref'],
                    data['ani_image'],
                ]
                flow_gt, conf_gt = flow_net(flow_inputs, epoch)

            data_list = [
                data['tgt_label'],
                data['tgt_image'],
                data['tgt_template'],
                data['cropped_images'],
                flow_gt,
                conf_gt,
            ]
            data_ref_list = [data['ref_label'], data['ref_image']]
            # Replaced after every chunk by the third value the generator
            # pass returns.
            data_prev = [None, None, None]
            data_ani = [
                data['warping_ref_lmark'],
                data['warping_ref'],
                data['ori_warping_refs'],
                data['ani_lmark'],
                data['ani_image'],
            ]

            # Fresh accumulator for this batch; one empty list per key.
            prevs = {key: [] for key in prev_keys}

            # --- forward / backward, one chunk of frames at a time ---
            for t in range(0, n_frames_total, n_frames_load):
                frames_t = get_data_t(data_list, n_frames_load, t)
                ani_t = get_data_t(data_ani, n_frames_load, t)
                data_list_t = frames_t + data_ref_list + ani_t + data_prev

                generator_out = model(
                    data_list_t,
                    save_images=trainer.save,
                    mode='generator',
                    ref_idx_fix=ref_idx_fix,
                )
                # The fourth returned value is not used in this loop.
                g_losses, generated, data_prev, _ref_idx = generator_out
                g_losses = loss_backward(
                    opt, g_losses, model.module.optimizer_G)

                d_losses, _ = model(
                    data_list_t,
                    mode='discriminator',
                    ref_idx_fix=ref_idx_fix,
                )
                d_losses = loss_backward(
                    opt, d_losses, model.module.optimizer_D)

                # Hand this chunk's generator output to the accumulator.
                store_prev(generated, prevs)

            # Only the losses of the last chunk reach the trainer.
            loss_names = model.module.lossCollector.loss_names
            loss_dict = dict(zip(loss_names, g_losses + d_losses))

            output_data_list = [
                prevs,
                data['ref_image'],
                *data_ani,
                *data_list,
                data['tgt_mask_images'],
            ]
            # A truthy return from end_of_iter ends this epoch's batch loop.
            if trainer.end_of_iter(loss_dict, output_data_list, model):
                break

        trainer.end_of_epoch(model)