def save_model(self, netDCM, avg_param_C, netD, epoch):
    # Temporarily swap in the moving-average weights before checkpointing.
    backup_para = copy_G_params(netDCM)
    load_params(netDCM, avg_param_C)
    torch.save(netDCM.state_dict(),
               '%s/netC_epoch_%d.pth' % (self.model_dir, epoch))
    load_params(netDCM, backup_para)
    torch.save(netD.state_dict(),
               '%s/netD_epoch_%d.pth' % (self.model_dir, epoch))
    print('Save C/D models.')
def save_model(self, netG, avg_param_G, netsD, epoch):
    # Temporarily swap in the moving-average weights before checkpointing.
    backup_para = copy_G_params(netG)
    load_params(netG, avg_param_G)
    torch.save(netG.state_dict(),
               '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
    load_params(netG, backup_para)
    #
    for i in range(len(netsD)):
        netD = netsD[i]
        torch.save(netD.state_dict(),
                   '%s/netD%d.pth' % (self.model_dir, i))
    print('Save G/Ds models.')
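# `copy_G_params` and `load_params` are helpers defined elsewhere in this
# repo and used by both save_model variants above and by the training loops
# below. A minimal sketch of the usual implementation (an assumption, not
# verified against this codebase): snapshot the parameter tensors so the
# moving-average weights can be swapped in and out.
from copy import deepcopy

def copy_G_params(model):
    # detach-copy every parameter tensor of the model
    return deepcopy([p.data for p in model.parameters()])

def load_params(model, new_param):
    # copy the saved tensors back into the model in place
    for p, new_p in zip(model.parameters(), new_param):
        p.data.copy_(new_p)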
def train(self):
    text_encoder, image_encoder, netG, netD, start_epoch, VGG, netDCM = \
        self.build_models()
    avg_param_C = copy_G_params(netDCM)
    optimizerC, optimizerD = self.define_optimizers(netDCM, netD)
    real_labels, fake_labels, match_labels = self.prepare_labels()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

    gen_iterations = 0
    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()

        data_iter = iter(self.data_loader)
        step = 0
        while step < self.num_batches:
            ######################################################
            # (1) Prepare training data and compute text embeddings
            ######################################################
            data = next(data_iter)
            imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                wrong_caps_len, wrong_cls_id = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            # matched text embeddings
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            # mismatched text embeddings
            w_words_embs, w_sent_emb = text_encoder(
                wrong_caps, wrong_caps_len, hidden)
            w_words_embs, w_sent_emb = \
                w_words_embs.detach(), w_sent_emb.detach()
            # image embeddings: regional and global
            region_features, cnn_code = image_encoder(
                imgs[cfg.TREE.BRANCH_NUM - 1])

            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Modify real images
            #######################################################
            noise.data.normal_(0, 1)
            fake_imgs, _, mu, logvar, h_code, c_code = netG(
                noise, sent_emb, words_embs, mask, cnn_code,
                region_features)

            real_img = imgs[cfg.TREE.BRANCH_NUM - 1]
            real_features = VGG(real_img)[0]
            fake_img = netDCM(h_code, real_features, sent_emb, words_embs,
                              mask, c_code)

            #######################################################
            # (3) Update D network
            #######################################################
            netD.zero_grad()
            errD = discriminator_loss(netD, imgs[cfg.TREE.BRANCH_NUM - 1],
                                      fake_img, sent_emb, real_labels,
                                      fake_labels, words_embs, cap_lens,
                                      image_encoder, class_ids, w_words_embs,
                                      wrong_caps_len, wrong_cls_id)
            # backward and update parameters
            errD.backward()
            optimizerD.step()
            D_logs = 'errD: %.2f ' % errD.item()

            #######################################################
            # (4) Update the DCM network
            #######################################################
            # compute total loss for training the DCM
            step += 1
            gen_iterations += 1

            netDCM.zero_grad()
            errC_total, C_logs = \
                DCM_generator_loss(netD, image_encoder, fake_img,
                                   real_labels, words_embs, sent_emb,
                                   match_labels, cap_lens, class_ids,
                                   VGG, real_img)
            # backward and update parameters
            errC_total.backward()
            optimizerC.step()
            # exponential moving average of the DCM weights
            for p, avg_p in zip(netDCM.parameters(), avg_param_C):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            if gen_iterations % 100 == 0:
                print(D_logs + '\n' + C_logs)
            # save images
            if gen_iterations % 1000 == 0:
                backup_para = copy_G_params(netDCM)
                load_params(netDCM, avg_param_C)
                self.save_img_results(netG, fixed_noise, sent_emb,
                                      words_embs, mask, image_encoder,
                                      captions, cap_lens, epoch, cnn_code,
                                      region_features, imgs, netDCM,
                                      real_features, name='average')
                load_params(netDCM, backup_para)

        end_t = time.time()
        print('''[%d/%d][%d] Loss_D: %.2f Loss_C: %.2f Time: %.2fs'''
              % (epoch, self.max_epoch, self.num_batches,
                 errD, errC_total, end_t - start_t))

        if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
            self.save_model(netDCM, avg_param_C, netD, epoch)

    self.save_model(netDCM, avg_param_C, netD, self.max_epoch)
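# The `avg_p.mul_(0.999).add_(p.data, alpha=0.001)` loop above is an
# exponential moving average (EMA) over generator/DCM weights:
# avg_p <- 0.999 * avg_p + 0.001 * p. A standalone sketch of the same
# update (hypothetical helper name `update_avg_params`, not part of this
# repo), which each training loop below could call after its optimizer step:
def update_avg_params(model, avg_params, decay=0.999):
    # blend the current weights into the running average in place
    for p, avg_p in zip(model.parameters(), avg_params):
        avg_p.mul_(decay).add_(p.data, alpha=1.0 - decay)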
def train(self):
    text_encoder, image_encoder, netG, netsD, start_epoch, VGG = \
        self.build_models()
    avg_param_G = copy_G_params(netG)
    optimizerG, optimizersD = self.define_optimizers(netG, netsD)
    real_labels, fake_labels, match_labels = self.prepare_labels()

    batch_size = self.batch_size
    nz = cfg.GAN.Z_DIM
    noise = Variable(torch.FloatTensor(batch_size, nz))
    fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
    if cfg.CUDA:
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

    gen_iterations = 0
    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()

        data_iter = iter(self.data_loader)
        step = 0
        while step < self.num_batches:
            ######################################################
            # (1) Prepare training data and compute text embeddings
            ######################################################
            data = next(data_iter)
            imgs, captions, cap_lens, class_ids, keys, wrong_caps, \
                wrong_caps_len, wrong_cls_id = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            # matched text embeddings
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
            # mismatched text embeddings
            w_words_embs, w_sent_emb = text_encoder(
                wrong_caps, wrong_caps_len, hidden)
            w_words_embs, w_sent_emb = \
                w_words_embs.detach(), w_sent_emb.detach()
            # image features: regional and global
            region_features, cnn_code = image_encoder(imgs[len(netsD) - 1])

            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Modify real images
            #######################################################
            noise.data.normal_(0, 1)
            fake_imgs, _, mu, logvar, _, _ = netG(
                noise, sent_emb, words_embs, mask, cnn_code,
                region_features)

            #######################################################
            # (3) Update D network
            #######################################################
            errD_total = 0
            D_logs = ''
            for i in range(len(netsD)):
                netsD[i].zero_grad()
                errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                          sent_emb, real_labels, fake_labels,
                                          words_embs, cap_lens, image_encoder,
                                          class_ids, w_words_embs,
                                          wrong_caps_len, wrong_cls_id)
                # backward and update parameters
                errD.backward(retain_graph=True)
                optimizersD[i].step()
                errD_total += errD
                D_logs += 'errD%d: %.2f ' % (i, errD.item())

            #######################################################
            # (4) Update G network: maximize log(D(G(z)))
            #######################################################
            # compute total loss for training G
            step += 1
            gen_iterations += 1

            netG.zero_grad()
            errG_total, G_logs = \
                generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                               words_embs, sent_emb, match_labels, cap_lens,
                               class_ids, VGG, imgs)
            kl_loss = KL_loss(mu, logvar)
            errG_total += kl_loss
            G_logs += 'kl_loss: %.2f ' % kl_loss.item()
            # backward and update parameters
            errG_total.backward()
            optimizerG.step()
            # exponential moving average of the generator weights
            for p, avg_p in zip(netG.parameters(), avg_param_G):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            if gen_iterations % 100 == 0:
                print(D_logs + '\n' + G_logs)
            # save images
            if gen_iterations % 1000 == 0:
                backup_para = copy_G_params(netG)
                load_params(netG, avg_param_G)
                self.save_img_results(netG, fixed_noise, sent_emb,
                                      words_embs, mask, image_encoder,
                                      captions, cap_lens, epoch, cnn_code,
                                      region_features, imgs, name='average')
                load_params(netG, backup_para)

        end_t = time.time()
        print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
              % (epoch, self.max_epoch, self.num_batches,
                 errD_total, errG_total, end_t - start_t))

        if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
            self.save_model(netG, avg_param_G, netsD, epoch)

    self.save_model(netG, avg_param_G, netsD, self.max_epoch)
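# `KL_loss(mu, logvar)` is imported from the repo's losses module and
# regularizes the conditioning-augmentation distribution toward a standard
# normal. A minimal sketch of the conventional closed-form implementation
# (an assumption, not copied from this repo):
import torch

def KL_loss(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return kld.mean()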
def train(self):
    text_encoder, image_encoder, netG, netsD, zsl_discriminator, start_epoch = \
        self.build_models()
    avg_param_G = copy_G_params(netG)
    optimizerG, optimizersD = self.define_optimizers(netG, netsD)
    real_labels, fake_labels, match_labels = self.prepare_labels()

    batch_size = self.batch_size
    gen_iterations = 0
    # gen_iterations = start_epoch * self.num_batches
    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()

        data_iter = iter(self.data_loader)
        it = tqdm.tqdm(range(self.num_batches)) \
            if tqdm is not None else range(self.num_batches)
        for step in it:
            # reset requires_grad to be trainable for all Ds
            # self.set_requires_grad_value(netsD, True)

            ######################################################
            # (1) Prepare training data and compute text embeddings
            ######################################################
            data = next(data_iter)
            imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

            hidden = text_encoder.init_hidden(batch_size)
            # words_embs: batch_size x nef x seq_len
            # sent_emb: batch_size x nef
            words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
            words_embs, sent_emb = words_embs.detach(), sent_emb.detach()

            mask = (captions == 0)
            num_words = words_embs.size(2)
            if mask.size(1) > num_words:
                mask = mask[:, :num_words]

            #######################################################
            # (2) Generate fake images
            #######################################################
            fake_imgs, _ = netG(sent_emb, words_embs, mask)

            #######################################################
            # (3) Update D network
            #######################################################
            errD_total = 0
            D_logs = ''
            for i in range(len(netsD)):
                netsD[i].zero_grad()
                errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                          sent_emb, real_labels, fake_labels)
                # backward and update parameters
                errD.backward()
                optimizersD[i].step()
                errD_total += errD
                D_logs += 'errD%d: %.2f ' % (i, errD.item())
                writer.add_scalar(f'd/errD/{i}', errD.item(), gen_iterations)

            #######################################################
            # (4) Update G network: maximize log(D(G(z)))
            #######################################################
            # compute total loss for training G
            # do not need to compute gradient for Ds
            # self.set_requires_grad_value(netsD, False)
            netG.zero_grad()
            errG_total, G_logs = generator_loss(
                netsD,
                image_encoder,
                zsl_discriminator,
                fake_imgs,
                real_labels,
                words_embs,
                sent_emb,
                match_labels,
                cap_lens,
                class_ids,
                writer=writer,
                global_step=gen_iterations,
            )
            # backward and update parameters
            errG_total.backward()
            optimizerG.step()
            # exponential moving average of the generator weights
            for p, avg_p in zip(netG.parameters(), avg_param_G):
                avg_p.mul_(0.999).add_(p.data, alpha=0.001)

            if gen_iterations % 100 == 0:
                LOGGER.info(f'{D_logs}\n{G_logs}')
            # save images
            if gen_iterations % 1000 == 0:
                backup_para = copy_G_params(netG)
                load_params(netG, avg_param_G)
                self.save_img_results(netG, sent_emb, words_embs, mask,
                                      image_encoder, captions, cap_lens,
                                      epoch, name='average')
                load_params(netG, backup_para)
                # self.save_img_results(netG, fixed_noise, sent_emb,
                #                       words_embs, mask, image_encoder,
                #                       captions, cap_lens,
                #                       epoch, name='current')
            gen_iterations += 1

        end_t = time.time()
        info = (f'[{epoch}/{self.max_epoch}][{self.num_batches}] '
                f'Loss_D: {errD_total.item():.2f} '
                f'Loss_G: {errG_total.item():.2f} '
                f'Time: {end_t - start_t:.2f}')
        LOGGER.info(info)

        if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
            self.save_model(netG, avg_param_G, netsD, epoch)

    self.save_model(netG, avg_param_G, netsD, self.max_epoch)
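# `set_requires_grad_value` appears only in the commented-out lines above.
# A minimal sketch of what such a helper would do (an assumption, mirroring
# the common pattern of freezing the discriminators during the G update):
def set_requires_grad_value(self, models_list, brequires):
    # toggle gradient tracking for every parameter of every model
    for model in models_list:
        for p in model.parameters():
            p.requires_grad = brequires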