def cyclgan_train(opt, cycle_gan: CycleGANModel, train_loader, writer_dict):
    """Run the full CycleGAN training loop (standalone, no architecture controller).

    Trains `cycle_gan` for `opt.n_epochs + opt.n_epochs_decay` epochs starting at
    `opt.epoch_count`, logging losses to stdout (via tqdm) and to the TensorBoard
    writer in `writer_dict`, and periodically saving visuals and checkpoints.

    Args:
        opt: options namespace; fields read here include epoch_count, n_epochs,
            n_epochs_decay, batch_size, print_freq, display_freq,
            save_latest_freq, save_epoch_freq.
        cycle_gan: the CycleGAN model to train (mutated in place).
        train_loader: iterable of training batches accepted by
            `cycle_gan.set_input`.
        writer_dict: dict with keys 'writer' (TensorBoard SummaryWriter) and
            'train_steps' (global step counter, incremented once per epoch).

    Side effects: updates model weights, writes checkpoints and result images
    to disk, writes scalars to TensorBoard, mutates writer_dict['train_steps'].
    """
    cycle_gan.train()
    writer = writer_dict['writer']
    total_iters = 0  # cumulative number of training samples seen (counted in batch_size units)
    t_data = 0.0     # last measured data-loading time, refreshed every print_freq iterations
    for epoch in trange(opt.epoch_count,
                        opt.n_epochs + opt.n_epochs_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0  # samples seen this epoch (batch_size units)
        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            if total_iters % opt.print_freq == 0:
                # time spent waiting on the data loader since the previous iteration
                t_data = iter_start_time - iter_data_time
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()
            if (i + 1) % opt.print_freq == 0:
                # NOTE(review): epoch_iter counts samples while len(train_loader) is
                # in batches, so "[Batch: x/y]" mixes units when batch_size > 1 —
                # kept as-is; confirm against the rest of the project.
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (
                    epoch, opt.n_epochs + opt.n_epochs_decay)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)
            if (total_iters + 1) % opt.display_freq == 0:
                # periodically render and save current translation results
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)
            if (total_iters + 1) % opt.save_latest_freq == 0:
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                save_suffix = 'latest'
                cycle_gan.save_networks(save_suffix)
            iter_data_time = time.time()
        if (epoch + 1) % opt.save_epoch_freq == 0:
            # keep a rolling 'latest' checkpoint plus a per-epoch snapshot
            cycle_gan.save_networks('latest')
            cycle_gan.save_networks(epoch)
        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.n_epochs + opt.n_epochs_decay,
                    time.time() - epoch_start_time))
        # TensorBoard scalars are logged once per epoch using the losses from
        # the last processed batch.
        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars(
            'Train/cycle', {
                "A": float(cycle_gan.loss_cycle_A),
                "B": float(cycle_gan.loss_cycle_B),
            }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)
        writer_dict['train_steps'] += 1
        cycle_gan.update_learning_rate()
# NOTE(review): this definition shadows the earlier `cyclgan_train` in the same
# module (same name, different signature) — only this controller-aware variant
# is callable after import. Confirm the duplication is intentional.
def cyclgan_train(opt, cycle_gan: CycleGANModel,
                  cycle_controller: CycleControllerModel, train_loader,
                  g_loss_history: RunningStats, d_loss_history: RunningStats,
                  writer_dict):
    """Shared (search-phase) CycleGAN training loop with a frozen controller.

    For each batch the controller samples an architecture (`forward()`) before
    the GAN optimizes, and running G/D loss variances are tracked; when both
    histories are full and either variance drops below
    `opt.dynamic_reset_threshold`, training aborts early with a dynamic-reset
    signal.

    Args:
        opt: options namespace; fields read include shared_epoch, batch_size,
            print_freq, display_freq, save_latest_freq, save_epoch_freq,
            dynamic_reset_threshold.
        cycle_gan: the shared CycleGAN model to train (mutated in place).
        cycle_controller: architecture controller, used in eval mode to sample
            an architecture each iteration.
        train_loader: iterable of batches accepted by `cycle_gan.set_input`.
        g_loss_history: running stats of the generator loss.
        d_loss_history: running stats of the summed discriminator losses.
        writer_dict: dict with 'writer' (TensorBoard SummaryWriter) and
            'train_steps' (global step, incremented once per epoch).

    Returns:
        bool: True if a dynamic reset was triggered (low loss variance),
        False after completing all `opt.shared_epoch` epochs.
    """
    cycle_gan.train()
    cycle_controller.eval()
    dynamic_reset = False
    writer = writer_dict['writer']
    total_iters = 0  # cumulative samples seen (batch_size units)
    t_data = 0.0     # last measured data-loading time
    for epoch in range(opt.shared_epoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            # sample an architecture from the (frozen) controller, then train
            # the shared GAN weights on this batch under that architecture
            cycle_controller.forward()
            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()
            g_loss_history.push(cycle_gan.loss_G.item())
            d_loss_history.push(cycle_gan.loss_D_A.item() +
                                cycle_gan.loss_D_B.item())
            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.shared_epoch)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)
            if (total_iters + 1) % opt.display_freq == 0:
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)
            if g_loss_history.is_full():
                # low variance in either loss means training has collapsed or
                # converged into a degenerate regime — signal a reset
                if g_loss_history.get_var() < opt.dynamic_reset_threshold \
                        or d_loss_history.get_var() < opt.dynamic_reset_threshold:
                    dynamic_reset = True
                    tqdm.write("=> dynamic resetting triggered")
                    g_loss_history.clear()
                    d_loss_history.clear()
                    return dynamic_reset
            if (
                    total_iters + 1
            ) % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))
                # NOTE(review): checkpointing is disabled here (save call
                # commented out in the original); the message above is kept
                # for log compatibility.
                # cycle_gan.save_networks(train_steps)
            iter_data_time = time.time()
        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            # cycle_gan.save_networks(train_steps)
        # FIX: this loop runs for opt.shared_epoch epochs, but the original
        # message reported opt.n_epochs + opt.n_epochs_decay as the total
        # (copy-paste from the non-controller trainer).
        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.shared_epoch,
                    time.time() - epoch_start_time))
        # Per-epoch TensorBoard scalars from the last processed batch.
        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars(
            'Train/cycle', {
                "A": float(cycle_gan.loss_cycle_A),
                "B": float(cycle_gan.loss_cycle_B),
            }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)
        writer_dict['train_steps'] += 1
    return dynamic_reset