def train(self, loaders):
    """Run the full adversarial training loop (latent- and reference-guided).

    Each iteration trains the discriminator twice (once against a
    latent-code-generated fake, once against a reference-image-generated
    fake), then the generator twice in the same two modes, keeps an EMA
    copy of the generator/mapping/style networks, and periodically logs,
    samples, checkpoints, and evaluates.

    Args:
        loaders: namespace with `src`, `ref`, and `val` dataloaders
            (consumed via InputFetcher).
    """
    args = self.args
    nets = self.nets
    nets_ema = self.nets_ema  # EMA shadow networks used for sampling/eval
    optims = self.optims

    # fetch random validation images for debugging
    fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
    fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
    # one fixed validation batch, reused for every debug-image dump
    inputs_val = next(fetcher_val)

    # resume training if necessary
    if args.resume_iter > 0:
        self._load_checkpoint(args.resume_iter)

    # remember the initial value of ds weight (diversity-sensitive loss);
    # args.lambda_ds is decayed in place below, so keep the starting value
    initial_lambda_ds = args.lambda_ds

    print('Start training...')
    start_time = time.time()
    for i in range(args.resume_iter, args.total_iters):
        # fetch images and labels
        inputs = next(fetcher)
        x_real, y_org = inputs.x_src, inputs.y_src
        x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
        z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

        # facial-landmark heatmaps, only when high-pass filtering is enabled
        masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

        # train the discriminator: pass 1 uses a latent code z_trg
        d_loss, d_losses_latent = compute_d_loss(
            nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # discriminator pass 2 uses a reference image x_ref
        d_loss, d_losses_ref = compute_d_loss(
            nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # train the generator: latent-guided pass also updates the
        # mapping network and style encoder
        g_loss, g_losses_latent = compute_g_loss(
            nets, args, x_real, y_org, y_trg,
            z_trgs=[z_trg, z_trg2], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()
        optims.mapping_network.step()
        optims.style_encoder.step()

        # reference-guided pass updates the generator only
        g_loss, g_losses_ref = compute_g_loss(
            nets, args, x_real, y_org, y_trg,
            x_refs=[x_ref, x_ref2], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()

        # compute moving average of network parameters
        moving_average(nets.generator, nets_ema.generator, beta=0.999)
        moving_average(nets.mapping_network, nets_ema.mapping_network,
                       beta=0.999)
        moving_average(nets.style_encoder, nets_ema.style_encoder,
                       beta=0.999)

        # decay weight for diversity sensitive loss (linear decay to ~0
        # over args.ds_iter iterations)
        if args.lambda_ds > 0:
            args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

        # print out log info
        if (i+1) % args.print_every == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
            log = "Elapsed time [%s], Iteration [%i/%i], " % (
                elapsed, i+1, args.total_iters)
            all_losses = dict()
            for loss, prefix in zip(
                    [d_losses_latent, d_losses_ref,
                     g_losses_latent, g_losses_ref],
                    ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                for key, value in loss.items():
                    all_losses[prefix + key] = value
            all_losses['G/lambda_ds'] = args.lambda_ds
            log += ' '.join(['%s: [%.4f]' % (key, value)
                             for key, value in all_losses.items()])
            print(log)
            wandb.log({"Losses": all_losses})

        # generate images for debugging
        if (i+1) % args.sample_every == 0:
            os.makedirs(args.sample_dir, exist_ok=True)
            utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1)

        # save model checkpoints
        if (i+1) % args.save_every == 0:
            self._save_checkpoint(step=i+1)

        # compute FID and LPIPS if necessary (uses the EMA networks)
        if (i+1) % args.eval_every == 0:
            calculate_metrics(nets_ema, args, i+1, mode='latent')
            calculate_metrics(nets_ema, args, i+1, mode='reference')
def train(self, loaders):
    """Run the epoch-based GAN training loop.

    Per inner iteration: fetch a training batch and a validation batch,
    flatten the temporal axis into the batch axis, train the discriminator
    once and the generator + style encoder once, then periodically log to
    stdout/TensorBoard, checkpoint, and dump debug images.

    Args:
        loaders: namespace with `train` and `val` dataloaders
            (consumed via InputFetcher).
    """
    args = self.args
    nets = self.nets
    optims = self.optims

    # NOTE(review): checkpoint resumption is not wired up in this trainer;
    # args.resume_iter only offsets the inner iteration counter below.

    # Fetchers wrap the dataloaders and yield namespaces of batched tensors.
    fetcher_tr = InputFetcher(loaders.train, 'train')
    fetcher_val = InputFetcher(loaders.val, 'val')

    print('Start training...')
    start_time = time.time()
    # hoist: number of iterations per epoch is loop-invariant
    iters_per_epoch = len(fetcher_tr)
    for e in range(args.epoch):
        for i in range(args.resume_iter, iters_per_epoch):
            # get input from training and validation from fetcher
            inputs = next(fetcher_tr)
            inputs_val = next(fetcher_val)
            gt_land, gt, gt_mask, prior = (inputs.gt_land, inputs.gt,
                                           inputs.gt_mask, inputs.prior)
            # detach so no autograd history from the loader leaks into
            # the loss graphs
            gt_land = gt_land.detach()
            gt = gt.detach()
            gt_mask = gt_mask.detach()
            prior = prior.detach()

            # gt.shape:    (batch, sync_t, c, h, w)
            # prior.shape: (batch, sync_t, c*2, h, w) -- TODO confirm
            # fold the sync_t time axis into the batch axis
            gt_land = gt_land.flatten(0, 1)
            gt = gt.flatten(0, 1)
            gt_mask = gt_mask.flatten(0, 1)
            prior = prior.flatten(0, 1)

            # --- train the discriminator ---
            # d_losses keys: real, fake, reg
            d_loss, d_losses = compute_d_loss(nets, args, gt_land, gt,
                                              gt_mask, prior)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # --- train the generator (style encoder updated jointly) ---
            # g_losses keys: adv, recon
            g_loss, g_losses = compute_g_loss(nets, args, gt_land, gt,
                                              gt_mask, prior)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.style_encoder.step()

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, iters_per_epoch)
                all_losses = dict()
                for loss, prefix in zip([d_losses, g_losses], ['D', 'G']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                # BUG FIX: the TensorBoard step was `e * i + i` (== i*(e+1)),
                # which is non-monotonic across epochs and collides between
                # them; use a proper global step instead.
                global_step = e * iters_per_epoch + i
                for k, v in all_losses.items():
                    self.summary_writer.add_scalar(k, v, global_step)

            # save model checkpoints (tagged "<epoch>_<iter>")
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=str(e) + '_' + str(i + 1))

            # generate images for debugging (no grads needed)
            with torch.no_grad():
                if (i + 1) % args.sample_every == 0:
                    os.makedirs(args.sample_dir, exist_ok=True)
                    utils.debug_image(nets, args, inputs=inputs_val,
                                      step=str(e) + '_' + str(i + 1))
def train(self, loaders):
    """PaddlePaddle (fluid dygraph) port of the adversarial training loop.

    Mirrors the PyTorch version: two discriminator passes (latent- and
    reference-guided), two generator passes, EMA of generator-side nets,
    lambda_ds decay, and periodic logging/sampling/checkpointing.
    Note: no FID/LPIPS evaluation step in this port.
    """
    # select device: the configured GPU if Paddle was built with CUDA,
    # otherwise fall back to CPU
    place = paddle.fluid.CUDAPlace(
        self.args.whichgpu) if paddle.fluid.is_compiled_with_cuda(
        ) else paddle.fluid.CPUPlace()
    # everything below runs in dygraph (imperative) mode on `place`
    with fluid.dygraph.guard(place):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema  # EMA shadow networks used for sampling
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref,
                               args.latent_dim, 'train')
        fetcher_val = InputFetcher(loaders.val, None,
                                   args.latent_dim, 'val')
        # one fixed validation batch, reused for every debug-image dump
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight; args.lambda_ds is
        # decayed in place below
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = (inputs.x_ref, inputs.x_ref2,
                                    inputs.y_ref)
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            # move numpy batches onto the device as dygraph variables;
            # images/latents are cast to float32, labels keep their dtype
            x_real = fluid.dygraph.to_variable(x_real.astype('float32'))
            y_org = fluid.dygraph.to_variable(y_org)
            x_ref = fluid.dygraph.to_variable(x_ref.astype('float32'))
            x_ref2 = fluid.dygraph.to_variable(x_ref2.astype('float32'))
            y_trg = fluid.dygraph.to_variable(y_trg)
            z_trg = fluid.dygraph.to_variable(z_trg.astype('float32'))
            z_trg2 = fluid.dygraph.to_variable(z_trg2.astype('float32'))

            # landmark heatmaps only when high-pass filtering is enabled
            masks = nets.fan.get_heatmap(
                x_real) if args.w_hpf > 0 else None

            # train the discriminator
            # print('train the discriminator')
            # pass 1: fake generated from latent code z_trg
            d_loss, d_losses_latent = compute_d_loss(nets,
                                                     args,
                                                     x_real,
                                                     y_org,
                                                     y_trg,
                                                     z_trg=z_trg,
                                                     masks=masks)
            self._reset_grad()
            # minimize() applies gradients; mean() reduces the loss to a
            # scalar as backward() requires
            d_loss_avg = fluid.layers.mean(d_loss)
            d_loss_avg.backward()
            optims.discriminator.minimize(d_loss_avg)

            # pass 2: fake generated from reference image x_ref
            d_loss, d_losses_ref = compute_d_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_ref=x_ref,
                                                  masks=masks)
            self._reset_grad()
            d_loss_avg = fluid.layers.mean(d_loss)
            d_loss_avg.backward()
            optims.discriminator.minimize(d_loss_avg)

            # train the generator
            # print('train the generator 1st')
            # latent-guided pass also updates mapping net + style encoder
            g_loss, g_losses_latent = compute_g_loss(
                nets,
                args,
                x_real,
                y_org,
                y_trg,
                z_trgs=[z_trg, z_trg2],
                masks=masks)
            self._reset_grad()
            g_loss_avg = fluid.layers.mean(g_loss)
            g_loss_avg.backward()  # stuck here with 1.8.x cpu version
            optims.generator.minimize(g_loss_avg)
            optims.mapping_network.minimize(g_loss_avg)
            optims.style_encoder.minimize(g_loss_avg)

            # print('train the generator 2nd')
            # reference-guided pass updates the generator only
            g_loss, g_losses_ref = compute_g_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_refs=[x_ref, x_ref2],
                                                  masks=masks)
            self._reset_grad()
            g_loss_avg = fluid.layers.mean(g_loss)
            g_loss_avg.backward()
            optims.generator.minimize(g_loss_avg)

            # compute moving average of network parameters
            # print('compute moving average of network parameters')
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network,
                           nets_ema.mapping_network,
                           beta=0.999)
            moving_average(nets.style_encoder,
                           nets_ema.style_encoder,
                           beta=0.999)
            # print('finish compute moving average of network parameters')

            # decay weight for diversity sensitive loss (linear decay
            # over args.ds_iter iterations)
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([
                        d_losses_latent, d_losses_ref, g_losses_latent,
                        g_losses_ref
                ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)

            # generate images for debugging (uses EMA networks)
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)
def train(self, loaders):
    """Adversarial training loop with extra AdaIN/whitening diagnostics.

    Standard two-pass (latent- and reference-guided) discriminator and
    generator training with EMA networks, plus periodic dumps of learned
    alpha parameters, pre-norm stds, whitening sqrt errors, and injected
    color statistics to text files under args.notes_path.
    """
    args = self.args
    nets = self.nets
    nets_ema = self.nets_ema  # EMA shadow networks used for sampling/eval
    optims = self.optims

    # fetch random validation images for debugging
    fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim,
                           'train')
    fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
    # one fixed validation batch, reused for every debug-image dump
    inputs_val = next(fetcher_val)

    # resume training if necessary
    if args.resume_iter > 0:
        self._load_checkpoint(args.resume_iter)

    # remember the initial value of ds weight; decayed in place below
    initial_lambda_ds = args.lambda_ds

    print('Start training...')
    start_time = time.time()
    for i in range(args.resume_iter, args.total_iters):
        # fetch images and labels
        inputs = next(fetcher)
        x_real, y_org = inputs.x_src, inputs.y_src
        x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
        z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

        # landmark heatmaps only when high-pass filtering is enabled
        masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

        # train the discriminator: pass 1 uses a latent code z_trg
        d_loss, d_losses_latent = compute_d_loss(nets,
                                                 args,
                                                 x_real,
                                                 y_org,
                                                 y_trg,
                                                 z_trg=z_trg,
                                                 masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # discriminator pass 2 uses a reference image x_ref
        d_loss, d_losses_ref = compute_d_loss(nets,
                                              args,
                                              x_real,
                                              y_org,
                                              y_trg,
                                              x_ref=x_ref,
                                              masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # train the generator: latent-guided pass also updates the
        # mapping network and style encoder
        g_loss, g_losses_latent = compute_g_loss(nets,
                                                 args,
                                                 x_real,
                                                 y_org,
                                                 y_trg,
                                                 z_trgs=[z_trg, z_trg2],
                                                 masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()
        optims.mapping_network.step()
        optims.style_encoder.step()

        # reference-guided pass updates the generator only
        g_loss, g_losses_ref = compute_g_loss(nets,
                                              args,
                                              x_real,
                                              y_org,
                                              y_trg,
                                              x_refs=[x_ref, x_ref2],
                                              masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()

        # compute moving average of network parameters
        moving_average(nets.generator, nets_ema.generator, beta=0.999)
        moving_average(nets.mapping_network, nets_ema.mapping_network,
                       beta=0.999)
        moving_average(nets.style_encoder, nets_ema.style_encoder,
                       beta=0.999)

        # decay weight for diversity sensitive loss (linear decay over
        # args.ds_iter iterations)
        if args.lambda_ds > 0:
            args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

        # print out log info
        if (i + 1) % args.print_every == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
            log = "Elapsed time [%s], Iteration [%i/%i], " % (
                elapsed, i + 1, args.total_iters)
            all_losses = dict()
            for loss, prefix in zip([
                    d_losses_latent, d_losses_ref, g_losses_latent,
                    g_losses_ref
            ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                for key, value in loss.items():
                    all_losses[prefix + key] = value
            all_losses['G/lambda_ds'] = args.lambda_ds
            log += ' '.join([
                '%s: [%.4f]' % (key, value)
                for key, value in all_losses.items()
            ])
            print(log)

        # generate images for debugging (uses EMA networks)
        if (i + 1) % args.sample_every == 0:
            os.makedirs(args.sample_dir, exist_ok=True)
            utils.debug_image(nets_ema, args, inputs=inputs_val,
                              step=i + 1)

        # save model checkpoints
        if (i + 1) % args.save_every == 0:
            self._save_checkpoint(step=i + 1)

        # compute FID and LPIPS if necessary
        if (i + 1) % args.eval_every == 0:
            calculate_metrics(nets_ema, args, i + 1, mode='latent')
            calculate_metrics(nets_ema, args, i + 1, mode='reference')

        # NOTE(review): mid-loop import; harmless (Python caches modules)
        # but conventionally belongs at the top of the file
        import adaiw

        # dump the learned AdaIN alpha (white/color) parameters of every
        # decoder block to notes_path/alphas.txt
        if (i + 1) % args.print_learned == 0:
            alphas_list = []
            for module in self.nets.generator.decode:
                if isinstance(module.norm1, adaiw.AdaIN):
                    alphas_list.append([
                        ('norm_1_white', module.norm1.alpha_white_param,
                         'norm_1_color', module.norm1.alpha_color_param),
                        ('norm_2_white', module.norm2.alpha_white_param,
                         'norm_2_color', module.norm2.alpha_color_param)
                    ])
            txt_f = os.path.join(args.notes_path, "alphas.txt")
            with open(txt_f, "a+") as f:
                f.write("{} iterations: \n".format(i + 1))
                for item in alphas_list:
                    f.write("{}\n".format(item))

        # dump pre-normalization activation stds per decoder block
        if (i + 1) % args.print_std == 0:
            with open(os.path.join(args.notes_path, "std.txt"),
                      "a+") as f:
                f.write("{} iterations: \n".format(i + 1))
                for j, module in enumerate(self.nets.generator.decode):
                    print([
                        module.std_b4_norm_1.item(),
                        module.std_b4_norm_2.item(),
                        module.std_b4_join.item(),
                        module.std_b4_output.item()
                    ],
                          file=f)

        # dump the last matrix-sqrt approximation errors of the
        # whitening normalizers
        if (i + 1) % args.print_sqrt_error == 0:
            with open(os.path.join(args.notes_path, "sqrt_errors.txt"),
                      "a+") as f:
                f.write("{} iterations: \n".format(i + 1))
                all_errors = []
                for j, module in enumerate(self.nets.generator.decode):
                    errors = []
                    for norm in [module.norm1, module.norm2]:
                        if isinstance(norm, adaiw.AdaIN) and isinstance(
                                norm.normalizer,
                                adaiw.normalizer.Whitening):
                            errors.append(norm.normalizer.last_error)
                    all_errors.append(tuple(errors))
                print(all_errors, file=f)

        # dump mean/std of the diagonal and strict lower triangle of the
        # last injected color statistics (first sample of the batch)
        if (i + 1) % args.print_color == 0:
            with torch.no_grad():
                with open(
                        os.path.join(args.notes_path, "last_color.txt"),
                        "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    for _, module in enumerate(
                            self.nets.generator.decode):
                        if hasattr(module.norm1, 'last_injected_stat'):
                            diag1, tri1 = module.norm1.last_injected_stat.diagonal(
                            ), module.norm1.last_injected_stat.tril(-1)
                            print("Mean Diag:",
                                  diag1[0].mean().item(),
                                  "Std Diag:",
                                  diag1[0].std().item(),
                                  file=f)
                            print("Mean Tril:",
                                  tri1[0].mean().item(),
                                  "Std Tril:",
                                  tri1[0].std().item(),
                                  file=f)
                            diag2, tri2 = module.norm2.last_injected_stat.diagonal(
                            ), module.norm2.last_injected_stat.tril(-1)
                            print("Mean Diag:",
                                  diag2[0].mean().item(),
                                  "Std Diag:",
                                  diag2[0].std().item(),
                                  file=f)
                            print("Mean Tril:",
                                  tri2[0].mean().item(),
                                  "Std Tril:",
                                  tri2[0].std().item(),
                                  file=f)
                            print(file=f)
def train(self, loaders):
    """Train the flow-conditioned GAN on FlyingChairs2-style data.

    Per iteration: two discriminator passes (latent- and reference-guided),
    two generator passes (the generator losses also return the target style
    codes, periodically saved as .npy), EMA of generator-side networks,
    lambda_ds decay, logging to stdout + losses.txt, debug sampling
    (including Sintel debug renders), checkpointing, and FID/LPIPS eval.

    Args:
        loaders: 2-tuple (fc2_loader, test_loader).
    """
    args = self.args
    nets = self.nets
    nets_ema = self.nets_ema  # EMA shadow networks used for sampling/eval
    optims = self.optims
    # debug switch: when 1, shifts the periodic triggers by one iteration
    # and drops into the debugger at the end of each iteration
    dbg = 0

    fc2_loader, test_loader = loaders
    fetcher = FC2Fetcher(fc2_loader, None, args.latent_dim)
    # one fixed batch, reused for every debug-image dump
    inputs_val = next(fetcher)

    # directory for periodic dumps of the target style codes
    s_trg_path = os.getcwd() + "/s_trg_out"
    self.grid = None
    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...)` guard
    os.makedirs(s_trg_path, exist_ok=True)

    # resume training if necessary
    if args.resume_iter > 0:
        self._load_checkpoint(args.resume_iter)

    # remember the initial value of ds weight; decayed in place below
    initial_lambda_ds = args.lambda_ds

    print('Start training...')
    start_time = time.time()
    for i in range(args.resume_iter, args.total_iters):
        # fetch images, labels, and optical-flow supervision
        inputs = next(fetcher)
        x_real, x_real2, y_org = inputs.x_src, inputs.x2_src, inputs.y_src
        x_ref, y_trg = inputs.x_ref, inputs.y_ref
        flow, mask = inputs.flow, inputs.mask
        z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

        # landmark heatmaps only when high-pass filtering is enabled
        masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

        # train the discriminator: pass 1 uses a latent code z_trg
        d_loss, d_losses_latent = self.compute_d_loss(nets,
                                                      args,
                                                      x_real,
                                                      y_org,
                                                      y_trg,
                                                      z_trg=z_trg,
                                                      masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # discriminator pass 2 uses a reference image x_ref
        d_loss, d_losses_ref = self.compute_d_loss(nets,
                                                   args,
                                                   x_real,
                                                   y_org,
                                                   y_trg,
                                                   x_ref=x_ref,
                                                   masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # train the generator: latent-guided pass also updates the
        # mapping network and style encoder; s_trg_lat is the style code
        # used, kept for the periodic .npy dump below
        g_loss, g_losses_latent, s_trg_lat = self.compute_g_loss(
            nets, args, x_real, x_real2, flow, mask, y_org, y_trg,
            z_trgs=[z_trg, z_trg2], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()
        optims.mapping_network.step()
        optims.style_encoder.step()

        # reference-guided pass (single reference) updates generator only
        g_loss, g_losses_ref, s_trg_ref = self.compute_g_loss(
            nets, args, x_real, x_real2, flow, mask, y_org, y_trg,
            x_refs=[x_ref, None], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()

        # compute moving average of network parameters
        moving_average(nets.generator, nets_ema.generator, beta=0.999)
        moving_average(nets.mapping_network, nets_ema.mapping_network,
                       beta=0.999)
        moving_average(nets.style_encoder, nets_ema.style_encoder,
                       beta=0.999)

        # decay weight for diversity sensitive loss (linear decay over
        # args.ds_iter iterations)
        if args.lambda_ds > 0:
            args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

        # print out log info (also appended to losses.txt)
        if (i + 1) % args.print_every == 0 + dbg:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
            log = "Elapsed time [%s], Iteration [%i/%i], " % (
                elapsed, i + 1, args.total_iters)
            all_losses = dict()
            for loss, prefix in zip([
                    d_losses_latent, d_losses_ref, g_losses_latent,
                    g_losses_ref
            ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                for key, value in loss.items():
                    all_losses[prefix + key] = value
            all_losses['G/lambda_ds'] = args.lambda_ds
            log += ' '.join([
                '%s: [%.4f]' % (key, value)
                for key, value in all_losses.items()
            ])
            print(log)
            with open('losses.txt', 'a') as log_file:
                log_file.write(log + '\n')

        # generate images for debugging (uses EMA networks)
        if (i + 1) % args.sample_every == 0 + dbg:
            os.makedirs(args.sample_dir, exist_ok=True)
            utils.debug_image(nets_ema, args, inputs=inputs_val,
                              step=i + 1)

        # save model checkpoints
        if (i + 1) % args.save_every == 0 + dbg:
            self._save_checkpoint(step=i + 1)

        # compute FID and LPIPS if necessary
        if (i + 1) % args.eval_every == 0:
            calculate_metrics(nets_ema, args, i + 1, mode='latent')
            calculate_metrics(nets_ema, args, i + 1, mode='reference')

        # save sintel debug renders, once unconditioned and once with the
        # fixed reference batch
        if (i + 1) % args.sample_every == 0 + dbg:
            self.debugSintel(nets, args, i + 1, test_loader)
            self.debugSintel(nets, args, i + 1, test_loader,
                             inputs_val.x_ref)

        # periodically dump the latent- and reference-derived style codes
        if (i + 1) % 100 == 0 + dbg:
            # BUG FIX: .numpy() fails on tensors that require grad or live
            # on the GPU; detach and move to host first (no-op otherwise)
            s_trg_out = torch.cat((s_trg_lat, s_trg_ref),
                                  0).detach().cpu().numpy()
            np.save(s_trg_path + "/s_trg_" + str(i + 1) + ".npy",
                    s_trg_out)

        if dbg == 1:
            # BUG FIX: was the undefined name `blah` (a NameError bomb);
            # use an explicit debugger trap instead
            breakpoint()