def train(self, data_loader, stage=1):
    """Run the adversarial training loop for Stage-I or Stage-II.

    For every batch: generate fakes from text embeddings + noise, update the
    discriminator, then update the generator (adversarial + KL loss). Both
    learning rates are halved every ``lr_decay_step`` epochs, scalar summaries
    and sample images are written every 100 batches, and snapshots are saved
    every ``self.snapshot_interval`` epochs.

    Args:
        data_loader: iterable of ``(real_img_cpu, txt_embedding)`` batches.
        stage: 1 loads the Stage-I networks, otherwise Stage-II.
    """
    # The network loaders also return the epoch to resume from.
    if stage == 1:
        netG, netD, start_epoch = self.load_network_stageI()
    else:
        netG, netD, start_epoch = self.load_network_stageII()

    nz = cfg.Z_DIM  # latent-noise dimensionality
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    # Sampled once up front so the periodically saved images are generated
    # from the same latent codes and stay comparable across epochs.
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    # Optimizer state (e.g. Adam moments) is restored alongside the networks.
    optimizerG, optimizerD = self.load_optimizers(netG, netD)
    count = 0
    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr

        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.float().cuda()

            ######################################################
            # (2) Generate fake images
            ######################################################
            # Re-sample the noise buffer in place for this batch.
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)

            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           mu, self.gpus)
            errD.backward()
            optimizerD.step()

            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs,
                                          real_labels, mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            # KL term regularizes the conditioning distribution (mu, logvar).
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()

            count = count + 1
            if i % 100 == 0:
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)

                # Save sample images generated from the fixed noise codes.
                inputs = (txt_embedding, fixed_noise)
                lr_fake, fake, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    # Stage-II also returns a low-resolution intermediate image.
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' % (epoch, self.max_epoch, i, len(data_loader), errD.item(), errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
            save_optimizer(optimizerG, optimizerD, self.model_dir)
    # Final snapshot after the last epoch, then flush the summary writer.
    save_model(netG, netD, self.max_epoch, self.model_dir)
    self.summary_writer.close()
def train(self, data_loader, stage=1):
    """Train the GAN with an auxiliary detection-consistency loss.

    Like the plain trainer, but each batch additionally runs the generated
    images through a Detectron model and compares one-hot encodings of the
    detected objects against those of the caption via SmoothL1Loss.

    Args:
        data_loader: iterable of ``(real_img_cpu, txt_embedding, caption)``
            batches.
        stage: 1 loads the Stage-I networks, otherwise Stage-II.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()

    nz = cfg.Z_DIM  # latent-noise dimensionality (100 per config comment)
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    # FIX: `volatile=True` was removed in PyTorch 0.4 and has no effect there;
    # use requires_grad=False like the sibling trainers in this file.
    fixed_noise = \
        Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                 requires_grad=False)
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
    optimizerD = \
        optim.Adam(netD.parameters(),
                   lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
    # Only optimize generator parameters that require gradients
    # (frozen sub-modules, e.g. a pre-trained text encoder, are skipped).
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    count = 0
    detectron = Detectron()
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr

        for i, data in enumerate(data_loader):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, txt_embedding, caption = data
            # Move the batch axis first; assumes caption arrives transposed
            # from the loader's default collation — TODO confirm with loader.
            caption = np.moveaxis(np.array(caption), 1, 0)
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                txt_embedding = txt_embedding.cuda()

            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)

            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            errD, errD_real, errD_wrong, errD_fake = \
                compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                           real_labels, fake_labels,
                                           mu, self.gpus)
            errD.backward()
            optimizerD.step()

            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            errG = compute_generator_loss(netD, fake_imgs,
                                          real_labels, mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            # Detection-consistency term: what Detectron sees in the fakes
            # vs. what the caption says should be there.
            fake_img = fake_imgs.cpu().detach().numpy()
            det_obj_list = detectron.get_labels(fake_img)
            fake_l = Variable(get_ohe(det_obj_list)).cuda()
            real_l = Variable(get_ohe(caption)).cuda()
            # NOTE(review): fake_l is derived from a detached numpy copy of
            # fake_imgs, so det_loss carries no gradient back into netG — it
            # only shifts the reported loss value. Confirm this is intended.
            det_loss = nn.SmoothL1Loss()(fake_l, real_l)
            errG_total = det_loss + errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()

            count = count + 1
            if i % 100 == 0:
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                summary_DET = summary.scalar('det_loss', det_loss.item())
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)
                self.summary_writer.add_summary(summary_DET, count)

                # Save sample images generated from the fixed noise codes.
                inputs = (txt_embedding, fixed_noise)
                lr_fake, fake, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    # Stage-II also returns a low-resolution intermediate image.
                    save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' % (epoch, self.max_epoch, i, len(data_loader), errD.item(), errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, epoch, self.model_dir)
    # save_model(netG, netD, self.max_epoch, self.model_dir)
    # self.summary_writer.close()
def train(self, data_loader, stage=1, max_objects=3):
    """Train the layout-aware GAN (bounding boxes + per-object labels).

    Each batch carries up to ``max_objects`` bounding boxes and class labels
    per image; affine transformation matrices derived from the boxes are fed
    to the generator and both loss functions. Stage 2 uses a second set of
    boxes/matrices for the higher-resolution branch.

    Args:
        data_loader: iterable of ``(real_img_cpu, bbox, label, txt_embedding)``
            batches.
        stage: 1 loads the Stage-I networks, otherwise Stage-II.
        max_objects: maximum number of objects (boxes/labels) per image.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()

    nz = cfg.Z_DIM
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                           requires_grad=False)
    # NOTE(review): fixed_noise is never used below — the periodic image dumps
    # reuse the last training `noise` batch. Confirm whether that is intended.
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    # Only optimize generator parameters that require gradients.
    netG_para = [p for p in netG.parameters() if p.requires_grad]
    optimizerD = optim.Adam(netD.parameters(),
                            lr=cfg.TRAIN.DISCRIMINATOR_LR,
                            betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))
    count = 0
    for epoch in range(self.max_epoch):
        start_t = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr

        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, bbox, label, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                if cfg.STAGE == 1:
                    bbox = bbox.cuda()
                elif cfg.STAGE == 2:
                    # Stage 2 carries boxes for both resolutions.
                    bbox = [bbox[0].cuda(), bbox[1].cuda()]
                label = label.cuda()
                txt_embedding = txt_embedding.cuda()

            # Derive (batch, max_objects, 2, 3) affine matrices from the boxes.
            if cfg.STAGE == 1:
                bbox = bbox.view(-1, 4)
                transf_matrices_inv = \
                    compute_transformation_matrix_inverse(bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox)
                transf_matrices = transf_matrices.view(
                    real_imgs.shape[0], max_objects, 2, 3)
            elif cfg.STAGE == 2:
                _bbox = bbox[0].view(-1, 4)
                transf_matrices_inv = \
                    compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                _bbox = bbox[1].view(-1, 4)
                transf_matrices_inv_s2 = \
                    compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices_s2 = compute_transformation_matrix(_bbox)
                transf_matrices_s2 = transf_matrices_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)

            # Produce one-hot encodings of the labels.
            _labels = label.long()
            # Map the -1 padding value to class 80 so one-hot conversion works.
            _labels[_labels < 0] = 80
            # FIX: allocate the one-hot buffer on the same device as _labels.
            # The original always used a CPU FloatTensor, which makes scatter_
            # fail with a device mismatch when cfg.CUDA moved `label` to GPU
            # (the sibling trainer in this file already branches like this).
            if cfg.CUDA:
                label_one_hot = torch.cuda.FloatTensor(
                    noise.shape[0], max_objects, 81).fill_(0)
            else:
                label_one_hot = torch.FloatTensor(
                    noise.shape[0], max_objects, 81).fill_(0)
            label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise,
                          transf_matrices_inv, label_one_hot)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          transf_matrices_s2, transf_matrices_inv_s2,
                          label_one_hot)
            _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                netG, inputs, self.gpus)

            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            if cfg.STAGE == 1:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot,
                                               transf_matrices,
                                               transf_matrices_inv,
                                               mu, self.gpus)
            elif cfg.STAGE == 2:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot,
                                               transf_matrices_s2,
                                               transf_matrices_inv_s2,
                                               mu, self.gpus)
            # Graph is reused by the generator update below.
            errD.backward(retain_graph=True)
            optimizerD.step()

            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            if cfg.STAGE == 1:
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot,
                                              transf_matrices,
                                              transf_matrices_inv,
                                              mu, self.gpus)
            elif cfg.STAGE == 2:
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot,
                                              transf_matrices_s2,
                                              transf_matrices_inv_s2,
                                              mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()

            count += 1
            if i % 500 == 0:
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)

                # Save sample images for this checkpoint (no grads needed).
                with torch.no_grad():
                    if cfg.STAGE == 1:
                        inputs = (txt_embedding, noise,
                                  transf_matrices_inv, label_one_hot)
                    elif cfg.STAGE == 2:
                        inputs = (txt_embedding, noise, transf_matrices_inv,
                                  transf_matrices_s2, transf_matrices_inv_s2,
                                  label_one_hot)
                    lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                        netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch,
                                     self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)

        # End-of-epoch image dump from the last batch's inputs.
        with torch.no_grad():
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise,
                          transf_matrices_inv, label_one_hot)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          transf_matrices_s2, transf_matrices_inv_s2,
                          label_one_hot)
            lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                netG, inputs, self.gpus)
            save_img_results(real_img_cpu, fake, epoch, self.image_dir)
            if lr_fake is not None:
                save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' % (epoch, self.max_epoch, i, len(data_loader), errD.item(), errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, optimizerG, optimizerD, epoch,
                       self.model_dir)
    # save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
    # self.summary_writer.close()
def train(self, data_loader, stage=1, max_objects=3):
    """Layout-aware GAN trainer with optimizer-state resume and extra logging.

    Same training scheme as the sibling bbox-based trainer, plus:
    restores optimizer state and resume epoch from ``cfg.NET_G`` when set,
    prints verbose per-interval diagnostics, and mirrors summaries to a
    second ("drive") summary writer every 100 batches.

    Args:
        data_loader: iterable of ``(real_img_cpu, bbox, label, txt_embedding)``
            batches.
        stage: 1 loads the Stage-I networks, otherwise Stage-II.
        max_objects: maximum number of objects (boxes/labels) per image.
    """
    if stage == 1:
        netG, netD = self.load_network_stageI()
    else:
        netG, netD = self.load_network_stageII()

    nz = cfg.Z_DIM
    batch_size = self.batch_size
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                           requires_grad=False)
    # NOTE(review): fixed_noise appears unused below — the image dumps reuse
    # the training `noise` buffer. Confirm whether that is intended.
    real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
    fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
        real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

    generator_lr = cfg.TRAIN.GENERATOR_LR
    discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
    lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

    # Only optimize generator parameters that require gradients.
    netG_para = []
    for p in netG.parameters():
        if p.requires_grad:
            netG_para.append(p)
    optimizerD = optim.Adam(netD.parameters(),
                            lr=cfg.TRAIN.DISCRIMINATOR_LR,
                            betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG_para,
                            lr=cfg.TRAIN.GENERATOR_LR,
                            betas=(0.5, 0.999))

    # Resume optimizer state and epoch counter from the checkpoint file
    # referenced by cfg.NET_G (the networks themselves were loaded above).
    startpoint = -1
    if cfg.NET_G != '':
        state_dict = torch.load(cfg.NET_G,
                                map_location=lambda storage, loc: storage)
        optimizerD.load_state_dict(state_dict["optimD"])
        optimizerG.load_state_dict(state_dict["optimG"])
        startpoint = state_dict["epoch"]
        print(startpoint)
        print('Load Optim and optimizers as : ', cfg.NET_G)

    count = 0
    drive_count = 0
    # Resume one epoch past the checkpointed one (-1 means start fresh).
    for epoch in range(startpoint + 1, self.max_epoch):
        print('epoch : ', epoch, ' drive_count : ', drive_count)
        epoch_start_time = time.time()
        print(epoch)
        start_t = time.time()
        start_t500 = time.time()
        # Halve both learning rates every lr_decay_step epochs.
        if epoch % lr_decay_step == 0 and epoch > 0:
            generator_lr *= 0.5
            for param_group in optimizerG.param_groups:
                param_group['lr'] = generator_lr
            discriminator_lr *= 0.5
            for param_group in optimizerD.param_groups:
                param_group['lr'] = discriminator_lr
        time_to_i = time.time()
        for i, data in enumerate(data_loader, 0):
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, bbox, label, txt_embedding = data
            real_imgs = Variable(real_img_cpu)
            txt_embedding = Variable(txt_embedding)
            if cfg.CUDA:
                real_imgs = real_imgs.cuda()
                if cfg.STAGE == 1:
                    bbox = bbox.cuda()
                elif cfg.STAGE == 2:
                    # Stage 2 carries boxes for both resolutions.
                    bbox = [bbox[0].cuda(), bbox[1].cuda()]
                label = label.cuda()
                txt_embedding = txt_embedding.cuda()

            # Derive (batch, max_objects, 2, 3) affine matrices from the boxes.
            if cfg.STAGE == 1:
                bbox = bbox.view(-1, 4)
                transf_matrices_inv = \
                    compute_transformation_matrix_inverse(bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox)
                transf_matrices = transf_matrices.view(
                    real_imgs.shape[0], max_objects, 2, 3)
            elif cfg.STAGE == 2:
                _bbox = bbox[0].view(-1, 4)
                transf_matrices_inv = \
                    compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv = transf_matrices_inv.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                _bbox = bbox[1].view(-1, 4)
                transf_matrices_inv_s2 = \
                    compute_transformation_matrix_inverse(_bbox)
                transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)
                transf_matrices_s2 = compute_transformation_matrix(_bbox)
                transf_matrices_s2 = transf_matrices_s2.view(
                    real_imgs.shape[0], max_objects, 2, 3)

            # Produce one-hot encodings of the labels.
            _labels = label.long()
            # Map the -1 padding value to class 80 so one-hot conversion works.
            _labels[_labels < 0] = 80
            # One-hot buffer must live on the same device as _labels.
            if cfg.CUDA:
                label_one_hot = torch.cuda.FloatTensor(
                    noise.shape[0], max_objects, 81).fill_(0)
            else:
                label_one_hot = torch.FloatTensor(noise.shape[0],
                                                  max_objects, 81).fill_(0)
            label_one_hot = label_one_hot.scatter_(2, _labels, 1).float()

            ######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise,
                          transf_matrices_inv, label_one_hot)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          transf_matrices_s2, transf_matrices_inv_s2,
                          label_one_hot)
            if cfg.CUDA:
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
            else:
                # CPU path calls netG directly (Stage-1 argument list only).
                print('Hiiiiiiiiiiii')
                _, fake_imgs, mu, logvar, _ = netG(txt_embedding, noise,
                                                   transf_matrices_inv,
                                                   label_one_hot)

            ######################################################
            # (3) Update D network
            ######################################################
            netD.zero_grad()
            if cfg.STAGE == 1:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot,
                                               transf_matrices,
                                               transf_matrices_inv,
                                               mu, self.gpus)
            elif cfg.STAGE == 2:
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot,
                                               transf_matrices_s2,
                                               transf_matrices_inv_s2,
                                               mu, self.gpus)
            # Graph is reused by the generator update below.
            errD.backward(retain_graph=True)
            optimizerD.step()

            ######################################################
            # (4) Update G network
            ######################################################
            netG.zero_grad()
            if cfg.STAGE == 1:
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot,
                                              transf_matrices,
                                              transf_matrices_inv,
                                              mu, self.gpus)
            elif cfg.STAGE == 2:
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot,
                                              transf_matrices_s2,
                                              transf_matrices_inv_s2,
                                              mu, self.gpus)
            kl_loss = KL_loss(mu, logvar)
            errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
            errG_total.backward()
            optimizerG.step()

            end_t = time.time()
            if i % 500 == 0:
                count += 1
                summary_D = summary.scalar('D_loss', errD.item())
                summary_D_r = summary.scalar('D_loss_real', errD_real)
                summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                summary_G = summary.scalar('G_loss', errG.item())
                summary_KL = summary.scalar('KL_loss', kl_loss.item())
                print('epoch : ', epoch)
                print('count : ', count)
                print(' i : ', i)
                print('Time to i : ', time.time() - time_to_i)
                time_to_i = time.time()
                print('D_loss : ', errD.item())
                print('D_loss_real : ', errD_real)
                print('D_loss_wrong : ', errD_wrong)
                print('D_loss_fake : ', errD_fake)
                print('G_loss : ', errG.item())
                print('KL_loss : ', kl_loss.item())
                print('generator_lr : ', generator_lr)
                print('discriminator_lr : ', discriminator_lr)
                print('lr_decay_step : ', lr_decay_step)
                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D_r, count)
                self.summary_writer.add_summary(summary_D_w, count)
                self.summary_writer.add_summary(summary_D_f, count)
                self.summary_writer.add_summary(summary_G, count)
                self.summary_writer.add_summary(summary_KL, count)

                # Save sample images for this checkpoint (no grads needed).
                with torch.no_grad():
                    if cfg.STAGE == 1:
                        inputs = (txt_embedding, noise,
                                  transf_matrices_inv, label_one_hot)
                    elif cfg.STAGE == 2:
                        inputs = (txt_embedding, noise, transf_matrices_inv,
                                  transf_matrices_s2, transf_matrices_inv_s2,
                                  label_one_hot)
                    if cfg.CUDA:
                        lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                    else:
                        lr_fake, fake, _, _, _ = netG(txt_embedding, noise,
                                                      transf_matrices_inv,
                                                      label_one_hot)
                    save_img_results(real_img_cpu, fake, epoch,
                                     self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            if i % 100 == 0:
                # NOTE(review): summary_D etc. are only refreshed when
                # i % 500 == 0, so writes at i = 100, 200, ... repeat the most
                # recent 500-interval values — confirm this staleness is OK.
                drive_count += 1
                self.drive_summary_writer.add_summary(summary_D, drive_count)
                self.drive_summary_writer.add_summary(summary_D_r, drive_count)
                self.drive_summary_writer.add_summary(summary_D_w, drive_count)
                self.drive_summary_writer.add_summary(summary_D_f, drive_count)
                self.drive_summary_writer.add_summary(summary_G, drive_count)
                self.drive_summary_writer.add_summary(summary_KL, drive_count)

        # End-of-epoch image dump from the last batch's inputs.
        with torch.no_grad():
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise,
                          transf_matrices_inv, label_one_hot)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv,
                          transf_matrices_s2, transf_matrices_inv_s2,
                          label_one_hot)
            lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                netG, inputs, self.gpus)
            save_img_results(real_img_cpu, fake, epoch, self.image_dir)
            if lr_fake is not None:
                save_img_results(None, lr_fake, epoch, self.image_dir)
        end_t = time.time()
        print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f Total Time: %.2fsec ''' % (epoch, self.max_epoch, i, len(data_loader), errD.item(), errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake, (end_t - start_t)))
        if epoch % self.snapshot_interval == 0:
            save_model(netG, netD, optimizerG, optimizerD, epoch,
                       self.model_dir)
        print("keyTime |||||||||||||||||||||||||||||||")
        print("epoch_time : ", time.time() - epoch_start_time)
        print("KeyTime |||||||||||||||||||||||||||||||")
    # save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
    # self.summary_writer.close()