def train_gpu(rank, world_size, opt, dataset):
    """Per-process training entry point (one process per GPU under DDP).

    rank: process index (rank 0 does all logging/visualization/saving).
    world_size: total number of processes.
    opt: parsed options namespace.
    dataset: dataset instance shared across ranks.
    """
    signal.signal(signal.SIGINT, signal_handler)  # to really kill the process on Ctrl-C
    if len(opt.gpu_ids) > 1:
        # multi-GPU: initialize the distributed process group on the given port
        setup(rank, world_size, opt.ddp_port)
    dataloader = create_dataloader(
        opt, rank,
        dataset)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)  # get the number of images in the dataset.
    model = create_model(
        opt, rank)  # create a model given opt.model and other options
    if hasattr(model, 'data_dependent_initialize'):
        # some models need one real batch before setup (e.g. data-dependent init)
        data = next(iter(dataloader))
        model.data_dependent_initialize(data)
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    if len(opt.gpu_ids) > 1:
        model.parallelize(rank)
    else:
        model.single_gpu()
    if rank == 0:
        # only rank 0 owns the visualizer; other ranks never touch it
        visualizer = Visualizer(
            opt)  # create a visualizer that display/save images and plots
    total_iters = 0  # the total number of training iterations
    if rank == 0:
        # stage a fixed validation batch on the model's device (rank 0 only)
        model.real_A_val, model.real_B_val = dataset.get_validation_set(
            opt.pool_size)
        model.real_A_val, model.real_B_val = model.real_A_val.to(
            model.device), model.real_B_val.to(model.device)
    if rank == 0 and opt.display_networks:
        data = next(iter(dataloader))
        for path in model.save_networks_img(data):
            visualizer.display_img(path + '.png')
    for epoch in range(
            opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
    ):  # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()  # timer for data loading per iteration
        epoch_iter = 0  # the number of training iterations in current epoch, reset to 0 every epoch
        if rank == 0:
            visualizer.reset(
            )  # reset the visualizer: make sure it saves the results to HTML at least once every epoch
        for i, data in enumerate(
                dataloader):  # inner loop (minibatch) within one epoch
            iter_start_time = time.time(
            )  # timer for computation per iteration
            t_data_mini_batch = iter_start_time - iter_data_time
            model.set_input(
                data)  # unpack data from dataloader and apply preprocessing
            model.optimize_parameters(
            )  # calculate loss functions, get gradients, update network weights
            t_comp = (time.time() - iter_start_time) / opt.batch_size
            # effective global batch size across all participating GPUs
            batch_size = model.get_current_batch_size() * len(opt.gpu_ids)
            total_iters += batch_size
            epoch_iter += batch_size
            if rank == 0:
                # NOTE: `total_iters % freq < batch_size` is a window test — it fires
                # once per `freq` images even when batch_size does not divide freq.
                if total_iters % opt.display_freq < batch_size:  # display images on visdom and save images to a HTML file
                    save_result = total_iters % opt.update_html_freq == 0
                    model.compute_visuals()
                    visualizer.display_current_results(
                        model.get_current_visuals(),
                        epoch,
                        save_result,
                        params=model.get_display_param())
                if total_iters % opt.print_freq < batch_size:  # print training losses and save logging information to the disk
                    losses = model.get_current_losses()
                    visualizer.print_current_losses(epoch, epoch_iter, losses,
                                                    t_comp, t_data_mini_batch)
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(
                            epoch,
                            float(epoch_iter) / dataset_size, losses)
                if total_iters % opt.save_latest_freq < batch_size:  # cache our latest model every <save_latest_freq> iterations
                    print(
                        'saving the latest model (epoch %d, total_iters %d)' %
                        (epoch, total_iters))
                    save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                    model.save_networks(save_suffix)
                if total_iters % opt.fid_every < batch_size and opt.compute_fid:
                    # periodically compute FID on the staged validation batch
                    model.compute_fid(epoch, total_iters)
                    if opt.display_id > 0:
                        fids = model.get_current_fids()
                        visualizer.plot_current_fid(
                            epoch,
                            float(epoch_iter) / dataset_size, fids)
                if total_iters % opt.D_accuracy_every < batch_size and opt.compute_D_accuracy:
                    model.compute_D_accuracy()
                    if opt.display_id > 0:
                        accuracies = model.get_current_D_accuracies()
                        visualizer.plot_current_D_accuracies(
                            epoch,
                            float(epoch_iter) / dataset_size, accuracies)
                if total_iters % opt.display_freq < batch_size and opt.APA:
                    # adaptive pseudo augmentation probability tracking
                    if opt.display_id > 0:
                        p = model.get_current_APA_prob()
                        visualizer.plot_current_APA_prob(
                            epoch,
                            float(epoch_iter) / dataset_size, p)
            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:  # cache our model every <save_epoch_freq> epochs
            if rank == 0:
                print('saving the model at the end of epoch %d, iters %d' %
                      (epoch, total_iters))
                model.save_networks('latest')
                model.save_networks(epoch)
        if rank == 0:
            print('End of epoch %d / %d \t Time Taken: %d sec' %
                  (epoch, opt.n_epochs + opt.n_epochs_decay,
                   time.time() - epoch_start_time))
        # all ranks step the schedulers so learning rates stay in sync
        model.update_learning_rate(
        )  # update learning rates at the end of every epoch.
def main():
    """Fine-tune a pix2pixHD-style model from pickled training options.

    Loads options from ./train/train_opt.pkl, overrides a few fields for the
    local dataset, then runs the standard alternating G/D training loop,
    checkpointing 'latest' whenever the combined loss improves.
    """
    with open('./train/train_opt.pkl', mode='rb') as f:
        opt = pickle.load(f)
    # local overrides of the pickled options
    opt.checkpoints_dir = './checkpoints/'
    opt.dataroot = './train'
    opt.no_flip = True
    opt.label_nc = 0
    opt.batchSize = 2
    print(opt)
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)
    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # phase offsets so the periodic checks fire on the same schedule after a resume
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq
    best_loss = 999999  # best combined G+D loss seen so far
    model = create_model(opt)
    model = model.cuda()
    visualizer = Visualizer(opt)
    # niter = 20, niter_decay = 20
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']),
                                      Variable(data['inst']),
                                      Variable(data['image']),
                                      Variable(data['feat']),
                                      infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)
            loss_DG = loss_D + loss_G  # combined scalar used only for best-model tracking

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                # FIX: was `v.data[0]`, which raises on 0-dim loss tensors in
                # PyTorch >= 0.4; use .item() (consistent with the rest of this file).
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch, total_steps)

            ### save latest model when the combined loss improved
            if total_steps % opt.save_latest_freq == save_delta and loss_DG < best_loss:
                # FIX: store a plain float, not the loss tensor — keeping the
                # tensor would retain the whole autograd graph in memory.
                best_loss = loss_DG.item()
                print(
                    'saving the latest model (epoch %d, total_steps %d ,total loss %g)'
                    % (epoch, total_steps, loss_DG.item()))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d ' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

        # release cached GPU memory between epochs
        torch.cuda.empty_cache()
def __init__(self, celebA_loader, rafd_loader, config):
    """Store data loaders and all hyper-parameters from `config`, then build
    the model (and tensorboard logger / pretrained weights if configured).

    celebA_loader / rafd_loader: dataset iterators for the two datasets.
    config: parsed argument namespace holding every setting read below.
    """
    # Data loader
    self.celebA_loader = celebA_loader
    self.rafd_loader = rafd_loader
    self.visualizer = Visualizer()

    # Model hyper-parameters
    self.c_dim = config.c_dim
    self.s_dim = config.s_dim
    self.c2_dim = config.c2_dim
    self.image_size = config.image_size
    self.g_conv_dim = config.g_conv_dim
    self.d_conv_dim = config.d_conv_dim
    self.g_repeat_num = config.g_repeat_num
    self.d_repeat_num = config.d_repeat_num
    self.d_train_repeat = config.d_train_repeat  # D updates per G update

    # Hyper-parameters (loss weights, learning rates, Adam betas)
    self.lambda_cls = config.lambda_cls
    self.lambda_rec = config.lambda_rec
    self.lambda_gp = config.lambda_gp
    self.lambda_s = config.lambda_s
    self.g_lr = config.g_lr
    self.d_lr = config.d_lr
    self.a_lr = config.a_lr
    self.beta1 = config.beta1
    self.beta2 = config.beta2

    # Criterion — 2D cross-entropy on GPU (assumes CUDA is available)
    self.criterion_s = CrossEntropyLoss2d(size_average=True).cuda()

    # Training settings
    self.dataset = config.dataset
    self.num_epochs = config.num_epochs
    self.num_epochs_decay = config.num_epochs_decay
    self.num_iters = config.num_iters
    self.num_iters_decay = config.num_iters_decay
    self.batch_size = config.batch_size
    self.use_tensorboard = config.use_tensorboard
    self.pretrained_model = config.pretrained_model

    # Test settings
    self.test_model = config.test_model
    self.config = config

    # Path
    self.log_path = config.log_path
    self.sample_path = config.sample_path
    self.model_save_path = config.model_save_path
    self.result_path = config.result_path

    # Step size
    self.log_step = config.log_step
    self.visual_step = self.log_step  # visualize at the same cadence as logging
    self.sample_step = config.sample_step
    self.model_save_step = config.model_save_step

    # Build tensorboard if use
    self.build_model()
    if self.use_tensorboard:
        self.build_tensorboard()

    # Start with trained model
    if self.pretrained_model:
        self.load_pretrained_model()
def main():
    """Run pix2pixHD-style inference over the test set and save results to an
    HTML results page. Supports half precision, uint8 inputs, ONNX export and
    TensorRT/ONNX runtime backends.
    """
    opt = TestOptions().parse(save=False)
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    visualizer = Visualizer(opt)

    # create website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(
        web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' %
        (opt.name, opt.phase, opt.which_epoch))

    # test
    if not opt.engine and not opt.onnx:
        model = create_model(opt)
        if opt.data_type == 16:
            model.half()
        elif opt.data_type == 8:
            model.type(torch.uint8)
        if opt.verbose:
            print(model)
    else:
        # NOTE(review): `model` is undefined on this path — opt.export_onnx
        # must not be combined with opt.engine/opt.onnx.
        from run_engine import run_trt_engine, run_onnx

    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        # cast inputs to match the model's data type
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            # FIX: torch.Tensor has no .uint8() method; .byte() is the
            # uint8 conversion and would previously raise AttributeError.
            data['label'] = data['label'].byte()
            data['inst'] = data['inst'].byte()
        if opt.export_onnx:
            print("Exporting to ONNX: ", opt.export_onnx)
            assert opt.export_onnx.endswith(
                "onnx"), "Export model file should end with .onnx"
            torch.onnx.export(model, [data['label'], data['inst']],
                              opt.export_onnx,
                              verbose=True)
            exit(0)
        minibatch = 1
        if opt.engine:
            generated = run_trt_engine(opt.engine, minibatch,
                                       [data['label'], data['inst']])
        elif opt.onnx:
            generated = run_onnx(opt.onnx, opt.data_type, minibatch,
                                 [data['label'], data['inst']])
        else:
            generated = model.inference(data['label'], data['inst'],
                                        data['image'])

        visuals = OrderedDict([
            ('input_label', util.tensor2label(data['label'][0],
                                              opt.label_nc)),
            ('synthesized_image', util.tensor2im(generated.data[0]))
        ])
        img_path = data['path']
        print('process image... %s' % img_path)
        visualizer.save_images(webpage, visuals, img_path)

    webpage.save()
def train():
    """vid2vid-style training loop (kwargs-based refactor: `opt` is a dict).

    Trains generator, per-frame discriminator and temporal discriminators
    over video sequences, then writes a G/D loss plot to the checkpoints dir.
    """
    opt = vars(TrainOptions().parse())

    # Initialize dataset:==============================================================================================#
    data_loader = create_dataloader(**opt)
    dataset_size = len(data_loader)
    print(f'Number of training videos = {dataset_size}')

    # Initialize models:===============================================================================================#
    models = prepare_models(**opt)
    model_g, model_d, flow_net, optimizer_g, optimizer_d, optimizer_d_t = create_optimizer(
        models, **opt)

    # Set parameters:==================================================================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = init_params(model_g, model_d, data_loader, **opt)
    visualizer = Visualizer(**opt)

    # Initialize loss list:============================================================================================#
    losses_G = []
    losses_D = []

    # Training start:==================================================================================================#
    for epoch in range(start_epoch, opt['niter'] + opt['niter_decay'] + 1):
        epoch_start_time = time.time()
        for idx, video in enumerate(data_loader, start=epoch_iter):
            if not total_steps % print_freq:
                iter_start_time = time.time()
            total_steps += opt['batch_size']
            epoch_iter += opt['batch_size']

            # whether to collect output images
            save_fake = total_steps % opt['display_freq'] == 0

            fake_B_prev_last = None
            # all real/generated frames so far
            real_B_all, fake_B_all, flow_ref_all, conf_ref_all = None, None, None, None
            if opt['sparse_D']:
                real_B_all, fake_B_all, flow_ref_all, conf_ref_all = [
                    None
                ] * t_scales, [None] * t_scales, [None] * t_scales, [
                    None
                ] * t_scales
            frames_all = real_B_all, fake_B_all, flow_ref_all, conf_ref_all

            for i, (input_A, input_B) in enumerate(VideoSeq(**video, **opt)):
                # Forward Pass:========================================================================================#
                # Generator:===========================================================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = model_g(
                    input_A, input_B, fake_B_prev_last)

                # Discriminator:=======================================================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flow_net(real_B, real_B_prev)
                fake_B_prev = model_g.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last
                losses = model_d(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(model_d.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                # NOTE(review): `video` is splatted as **video above (a mapping),
                # but attribute access is used here — confirm VideoSeq/loader type.
                frames_all, frames_skipped = \
                    model_d.get_all_skipped_frames(frames_all, real_B, fake_B,
                                                   flow_ref, conf_ref, t_scales, tD,
                                                   video.n_frames_load, i, flow_net)

                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = model_d(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(model_d.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = model_d.get_losses(
                    loss_dict, loss_dict_T, t_scales)
                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(loss_G, optimizer_g)
                # update individual discriminator weights
                loss_backward(loss_D, optimizer_d)
                # update temporal discriminator weights
                # FIX: was loss_backward(opt, loss_D_T[s], optimizer_d_t[s]) —
                # inconsistent with the two calls above and would raise TypeError.
                for s in range(t_scales_act):
                    loss_backward(loss_D_T[s], optimizer_d_t[s])

                # the first generated image in this sequence
                if i == 0:
                    fake_B_first = fake_B[0, 0]

            # Display results and errors:==============================================#
            # Print out errors:================================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                    else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                        if not isinstance(v, int) \
                        else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, model_d)
                visualizer.display_current_results(visuals, epoch,
                                                   total_steps)

            # Save latest model:===============================================#
            save_models(epoch, epoch_iter, total_steps, visualizer, iter_path,
                        model_g, model_d, **opt)
            if epoch_iter > dataset_size - opt['batch_size']:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print(
            f'End of epoch {epoch} / {opt["niter"] + opt["niter_decay"]} \t'
            f' Time Taken: {time.time() - epoch_start_time} sec')

        # save model for this epoch and update model params:=================#
        save_models(epoch, epoch_iter, total_steps, visualizer, iter_path,
                    model_g, model_d, end_of_epoch=True, **opt)
        update_models(epoch, model_g, model_d, data_loader, **opt)

    from matplotlib import pyplot as plt
    plt.switch_backend('agg')
    print("Generator Loss: %f." % losses_G[-1])
    print("Discriminator Loss: %f." % losses_D[-1])
    # Plot Losses
    plt.plot(losses_G, '-b', label='losses_G')
    plt.plot(losses_D, '-r', label='losses_D')
    # plt.plot(losses_D_T, '-r', label='losses_D_T')
    plot_name = 'checkpoints/' + opt['name'] + '/losses_plot.png'
    plt.savefig(plot_name)
    plt.close()
# data_loader = getLoader(opt) # dataset_size = len(data_loader) # print('#training images = %d' % dataset_size) model = CasGAN() model.initialize(opt) # visualizer = Visualizer(opt) # total_steps = 0 # fixed_noise = torch.FloatTensor(opt.valBatchSize, opt.hidden_size, 1, 1).uniform_(-1, 1) # fixed_noise = fixed_noise.cuda() image = rtf(opt.img_path, opt) image = image.unsqueeze(0) print(image.size()) n_fake, c_fake = model.encode(image) vis = Visualizer(opt) c_fake_m = c_fake.clone() c_fake_m.data[0, 0, 0, 0] = 1 vis.plot_current_label((c_fake, c_fake_m), 1) img_generated = model.decode(n_fake, c_fake) img_generated = tensor2im(img_generated) from matplotlib import pyplot as plt plt.imshow(img_generated, interpolation='nearest') plt.show() # img_generated = model.decode(n_fake,c_fake_m) # img_generated = tensor2im(img_generated) # from matplotlib import pyplot as plt # plt.imshow(img_generated, interpolation='nearest') # plt.show()
def train():
    """vid2vid training loop (argparse-namespace variant of the loop above).

    Trains DataParallel-wrapped generator / discriminators over sliced video
    sequences and plots G/D losses at the end.
    """
    opt = TrainOptions().parse()
    if opt.debug:
        # tiny settings so a debug run touches every code path quickly
        opt.display_freq = 1
        opt.print_freq = 1
        opt.nThreads = 1

    # Initialize dataset:======================================================#
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('Number of training videos = %d' % dataset_size)

    # Initialize models:=======================================================#
    models = prepare_models(opt)
    modelG, modelD, flowNet, optimizer_G, optimizer_D, optimizer_D_T = \
        create_optimizer(opt, models)

    # Set parameters:==========================================================#
    n_gpus, tG, tD, tDB, s_scales, t_scales, input_nc, output_nc, \
        start_epoch, epoch_iter, print_freq, total_steps, iter_path = \
        init_params(opt, modelG, modelD, data_loader)
    visualizer = Visualizer(opt)

    # Initialize loss list:====================================================#
    losses_G = []
    losses_D = []
    losses_D_T = []
    losses_t_scales = []

    # Real training starts here:===============================================#
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for idx, data in enumerate(dataset, start=epoch_iter):
            if total_steps % print_freq == 0:
                iter_start_time = time.time()
            total_steps += opt.batch_size
            epoch_iter += opt.batch_size

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == 0

            # per-video frame-slicing parameters and fresh per-sequence state
            n_frames_total, n_frames_load, t_len = \
                data_loader.dataset.init_data_params(data, n_gpus, tG)
            fake_B_prev_last, frames_all = data_loader.dataset.init_data(
                t_scales)

            # iterate over the video in chunks of n_frames_load frames
            for i in range(0, n_frames_total, n_frames_load):
                input_A, input_B, input_C, inst_A = \
                    data_loader.dataset.prepare_data(data, i, input_nc, output_nc)

                ############################### Forward Pass ###############################
                ####### Generator:=========================================================#
                fake_B, fake_B_raw, flow, weight, real_A, real_Bp, fake_B_last = \
                    modelG(input_A, input_B, inst_A, fake_B_prev_last)

                ####### Discriminator:=====================================================#
                # individual frame discriminator:==============================#
                # the collection of previous and current real frames
                real_B_prev, real_B = real_Bp[:, :-1], real_Bp[:, 1:]
                # reference flows and confidences
                flow_ref, conf_ref = flowNet(real_B, real_B_prev)
                fake_B_prev = modelG.module.compute_fake_B_prev(
                    real_B_prev, fake_B_prev_last, fake_B)
                fake_B_prev_last = fake_B_last
                losses = modelD(
                    0,
                    reshape([
                        real_B, fake_B, fake_B_raw, real_A, real_B_prev,
                        fake_B_prev, flow, weight, flow_ref, conf_ref, input_C
                    ]))
                losses = [
                    torch.mean(x) if x is not None else 0 for x in losses
                ]
                loss_dict = dict(zip(modelD.module.loss_names, losses))

                # Temporal Discriminator:======================================#
                # get skipped frames for each temporal scale
                frames_all, frames_skipped = \
                    modelD.module.get_all_skipped_frames(frames_all, real_B,
                                                         fake_B, flow_ref, conf_ref,
                                                         t_scales, tD, n_frames_load,
                                                         i, flowNet)

                # run discriminator for each temporal scale:===================#
                loss_dict_T = []
                for s in range(t_scales):
                    if frames_skipped[0][s] is not None:
                        losses = modelD(s + 1, [
                            frame_skipped[s]
                            for frame_skipped in frames_skipped
                        ])
                        losses = [
                            torch.mean(x) if not isinstance(x, int) else x
                            for x in losses
                        ]
                        loss_dict_T.append(
                            dict(zip(modelD.module.loss_names_T, losses)))

                # Collect losses:==============================================#
                loss_G, loss_D, loss_D_T, t_scales_act = \
                    modelD.module.get_losses(loss_dict, loss_dict_T, t_scales)
                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

                ################################## Backward Pass ###########################
                # Update generator weights
                loss_backward(opt, loss_G, optimizer_G)
                # update individual discriminator weights
                loss_backward(opt, loss_D, optimizer_D)
                # update temporal discriminator weights
                for s in range(t_scales_act):
                    loss_backward(opt, loss_D_T[s], optimizer_D_T[s])

                # the first generated image in this sequence
                if i == 0:
                    fake_B_first = fake_B[0, 0]

            if opt.debug:
                # dump GPU memory stats in debug mode
                call([
                    "nvidia-smi", "--format=csv",
                    "--query-gpu=memory.used,memory.free"
                ])

            # Display results and errors:==============================================#
            # Print out errors:================================================#
            if total_steps % print_freq == 0:
                t = (time.time() - iter_start_time) / print_freq
                errors = {k: v.data.item() if not isinstance(v, int) \
                    else v for k, v in loss_dict.items()}
                for s in range(len(loss_dict_T)):
                    errors.update({k + str(s): v.data.item() \
                        if not isinstance(v, int) \
                        else v for k, v in loss_dict_T[s].items()})
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            # Display output images:===========================================#
            if save_fake:
                visuals = util.save_all_tensors(opt, real_A, fake_B,
                                                fake_B_first, fake_B_raw,
                                                real_B, flow_ref, conf_ref,
                                                flow, weight, modelD)
                visualizer.display_current_results(visuals, epoch,
                                                   total_steps)

            # Save latest model:===============================================#
            save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                        iter_path, modelG, modelD)
            if epoch_iter > dataset_size - opt.batch_size:
                epoch_iter = 0
                break

        # End of epoch:========================================================#
        visualizer.vis_print('End of epoch %d / %d \t Time Taken: %d sec' % \
            (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        ### save model for this epoch and update model params:=================#
        save_models(opt, epoch, epoch_iter, total_steps, visualizer,
                    iter_path, modelG, modelD, end_of_epoch=True)
        update_models(opt, epoch, modelG, modelD, data_loader)

    from matplotlib import pyplot as plt
    plt.switch_backend('agg')  # headless backend: save the figure, never show it
    print("Generator Loss: %f." % losses_G[-1])
    print("Discriminator loss: %f." % losses_D[-1])
    #Plot Losses
    plt.plot(losses_G, '-b', label='losses_G')
    plt.plot(losses_D, '-r', label='losses_D')
    # plt.plot(losses_D_T, '-r', label='losses_D_T')
    plot_name = 'checkpoints/' + opt.name + '/losses_plot.png'
    plt.savefig(plot_name)
    plt.close()
def main():
    """pix2pixHD training loop (no arguments, no return value).

    Relies on a module-level `opt` namespace; runs the standard alternating
    G/D optimization and periodic checkpointing. All comments translated to
    English from the original Korean.
    """
    # path of the iteration-resume file
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    # build the data loader for the current options
    data_loader = CreateDataLoader(opt)
    # get the dataset from the data loader
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    start_epoch, epoch_iter = 1, 0
    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # phase offsets for the periodic display/print/save checks
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    model = create_model(opt)
    # model = model.cuda()
    visualizer = Visualizer(opt)  # reports training progress for these options

    # e.g. niter + niter_decay = 40 total epochs
    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            epoch_iter = epoch_iter % dataset_size
        for i, data in enumerate(dataset, start=epoch_iter):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            losses, generated = model(Variable(data['label']),
                                      Variable(data['inst']),
                                      Variable(data['image']),
                                      Variable(data['feat']),
                                      infer=save_fake)

            # sum per device losses
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            loss_dict = dict(zip(model.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights
            model.optimizer_G.zero_grad()
            loss_G.backward()
            model.optimizer_G.step()

            # update discriminator weights
            model.optimizer_D.zero_grad()
            loss_D.backward()
            model.optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                # FIX: was `v.data` (a tensor); extract a plain float with
                # .item(), consistent with the other training loops in this file.
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)

            ### display output images
            if save_fake:
                visuals = OrderedDict([
                    ('input_label',
                     util.tensor2label(data['label'][0], opt.label_nc)),
                    ('synthesized_image', util.tensor2im(generated.data[0])),
                    ('real_image', util.tensor2im(data['image'][0]))
                ])
                visualizer.display_current_results(visuals, epoch,
                                                   total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.update_learning_rate()

        # release cached GPU memory between epochs
        torch.cuda.empty_cache()
def train():
    """pix2pixHD-style training with optional apex AMP (fp16), metric logging,
    and a full validation pass at the end of every epoch.
    """
    opt = TrainOptions().parse()
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if opt.continue_train:
        try:
            start_epoch, epoch_iter = np.loadtxt(iter_path,
                                                 delimiter=',',
                                                 dtype=int)
        except:  # missing/corrupt iter file -> start fresh
            start_epoch, epoch_iter = 1, 0
        # compute resume lr
        if start_epoch > opt.niter:
            lrd_unit = opt.lr / opt.niter_decay
            resume_lr = opt.lr - (start_epoch - opt.niter) * lrd_unit
            opt.lr = resume_lr
        print('Resuming from epoch %d at iteration %d' %
              (start_epoch, epoch_iter))
    else:
        start_epoch, epoch_iter = 1, 0
    # make print_freq a multiple of the batch size so the modulo check can fire
    opt.print_freq = lcm(opt.print_freq, opt.batchSize)
    if opt.debug:
        # tiny settings for a fast end-to-end debug run
        opt.display_freq = 2
        opt.print_freq = 2
        opt.niter = 3
        opt.niter_decay = 0
        opt.max_dataset_size = 1
        opt.valSize = 1

    ## Loading data
    # train data
    data_loader = CreateDataLoader(opt, isVal=False)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('# training images = %d' % dataset_size)
    # validation data
    data_loader = CreateDataLoader(opt, isVal=True)
    valset = data_loader.load_data()
    print('# validation images = %d' % len(data_loader))

    ## Loading model
    model = create_model(opt)
    visualizer = Visualizer(opt)
    if opt.fp16:
        # mixed-precision: wrap model+optimizers with apex AMP, then DataParallel
        from apex import amp
        model, [optimizer_G, optimizer_D
                ] = amp.initialize(model,
                                   [model.optimizer_G, model.optimizer_D],
                                   opt_level='O1')
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    else:
        # NOTE(review): accessing model.module here implies create_model already
        # returns a DataParallel-wrapped model in the non-fp16 path — confirm.
        optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

    total_steps = (start_epoch - 1) * dataset_size + epoch_iter
    # phase offsets so periodic checks stay aligned after a resume
    display_delta = total_steps % opt.display_freq
    print_delta = total_steps % opt.print_freq
    save_delta = total_steps % opt.save_latest_freq

    for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        if epoch != start_epoch:
            # epoch_iter = epoch_iter % dataset_size
            epoch_iter = 0
        for i, data in enumerate(dataset, start=epoch_iter):
            if total_steps % opt.print_freq == print_delta:
                iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize

            # whether to collect output images
            save_fake = total_steps % opt.display_freq == display_delta

            ############## Forward Pass ######################
            model = model.train()
            losses, generated, metrics = model(data['A'],
                                               data['B'],
                                               data['geometry'],
                                               infer=False)

            # sum per device losses and metrics
            losses = [
                torch.mean(x) if not isinstance(x, int) else x for x in losses
            ]
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            loss_dict = dict(zip(model.module.loss_names, losses))

            # calculate final loss scalar
            loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
            loss_G = loss_dict['G_GAN'] + loss_dict.get(
                'G_GAN_Feat', 0) + opt.vgg_weight * loss_dict.get('G_VGG', 0)

            ############### Backward Pass ####################
            # update generator weights (AMP loss scaling in fp16 mode)
            optimizer_G.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_G, optimizer_G) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_G.backward()
            optimizer_G.step()

            # update discriminator weights
            optimizer_D.zero_grad()
            if opt.fp16:
                with amp.scale_loss(loss_D, optimizer_D) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_D.backward()
            optimizer_D.step()

            ############## Display results and errors ##########
            ### print out errors
            if total_steps % opt.print_freq == print_delta:
                errors = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in loss_dict.items()
                }
                metrics_ = {
                    k: v.data.item() if not isinstance(v, int) else v
                    for k, v in metric_dict.items()
                }
                t = (time.time() - iter_start_time) / opt.print_freq
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                visualizer.plot_current_errors(errors, total_steps)
                visualizer.print_current_metrics(epoch, epoch_iter, metrics_,
                                                 t)
                visualizer.plot_current_metrics(metrics_, total_steps)
                #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

            ### display output images (EXR conversion type depends on task)
            if save_fake:
                if opt.task_type == 'specular':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=1)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=1))
                    ])
                elif opt.task_type == 'low':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=2)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=2))
                    ])
                elif opt.task_type == 'high':
                    visuals = OrderedDict([
                        ('albedo', util.tensor2im(data['A'][0])),
                        ('generated',
                         util.tensor2im_exr(generated.data[0], type=3)),
                        ('GT', util.tensor2im_exr(data['B'][0], type=3))
                    ])
                visualizer.display_current_results(visuals, epoch,
                                                   total_steps)

            ### save latest model
            if total_steps % opt.save_latest_freq == save_delta:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.module.save('latest')
                np.savetxt(iter_path, (epoch, epoch_iter),
                           delimiter=',',
                           fmt='%d')

            if epoch_iter >= dataset_size:
                break

        # end of epoch
        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay,
               time.time() - epoch_start_time))

        ###########################################################################################
        # validation at the end of each epoch
        val_start_time = time.time()
        metrics_val = []
        for _, val_data in enumerate(valset):
            model = model.eval()
            # model.half()
            generated, metrics = model(val_data['A'],
                                       val_data['B'],
                                       val_data['geometry'],
                                       infer=True)
            metric_dict = {k: torch.mean(v) for k, v in metrics.items()}
            metrics_ = {
                k: v.data.item() if not isinstance(v, int) else v
                for k, v in metric_dict.items()
            }
            metrics_val.append(metrics_)

        # Print out losses (averaged over the whole validation set)
        metrics_val = visualizer.mean4dict(metrics_val)
        t = (time.time() - val_start_time) / opt.print_freq
        visualizer.print_current_metrics(epoch,
                                         epoch_iter,
                                         metrics_val,
                                         t,
                                         isVal=True)
        visualizer.plot_current_metrics(metrics_val, total_steps, isVal=True)

        # visualization (uses the last validation batch)
        if opt.task_type == 'specular':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=1)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=1))
            ])
        if opt.task_type == 'low':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=2)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=2))
            ])
        if opt.task_type == 'high':
            visuals = OrderedDict([
                ('albedo', util.tensor2im(val_data['A'][0])),
                ('generated', util.tensor2im_exr(generated.data[0], type=3)),
                ('GT', util.tensor2im_exr(val_data['B'][0], type=3))
            ])
        visualizer.display_current_results(visuals, epoch, epoch, isVal=True)
        ###########################################################################################

        ### save model for this epoch
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.module.save('latest')
            model.module.save(epoch)
            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')

        ### instead of only training the local enhancer, train the entire network after certain iterations
        if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
            model.module.update_fixed_params()

        ### linearly decay learning rate after certain iterations
        if epoch > opt.niter:
            model.module.update_learning_rate()