def _set_requires_grad(nets, flag):
    """Set ``requires_grad`` to *flag* for every parameter of each module in *nets*."""
    for net in nets:
        for param in net.parameters():
            param.requires_grad = flag


def _save_checkpoint(model, subspace, opts, model_file, subspace_file, ep,
                     total_it):
    """Save the DRIT model and the subspace ``state_dict``.

    Both files are written to ``os.path.join(opts.result_dir, opts.name)``;
    the DRIT model goes through ``model.save(path, ep, total_it)`` and the
    subspace through ``torch.save``.
    """
    print('save model')
    filename = os.path.join(opts.result_dir, opts.name)
    model.save('%s/%s' % (filename, model_file), ep, total_it)
    torch.save(subspace.state_dict(), '%s/%s' % (filename, subspace_file))


def _encode_split(subspace, model, loader, opts, num_rows, max_batches=None):
    """Embed a split with the subspace and with DRIT's ``test_model2``.

    Returns four numpy arrays ``(img_sub, cap_sub, img_enc, cap_enc)`` with
    ``num_rows`` rows each. Each batch's embeddings are written to the rows
    named by that batch's ``ids``; rows never visited stay zero. When
    *max_batches* is not None only the first *max_batches* batches are used.

    Raises:
        ValueError: if no batch was encoded (empty loader or max_batches=0).
    """
    # Always evaluate with the subspace in evaluation mode.
    subspace.val_start()
    img_sub = cap_sub = img_enc = cap_enc = None
    # ``volatile=True`` is still forwarded to forward_emb as before; no_grad()
    # additionally stops autograd from recording the evaluation pass
    # (Variable's ``volatile`` flag has no effect since PyTorch 0.4).
    with torch.no_grad():
        for it, (images, captions, lengths, ids) in enumerate(loader):
            if max_batches is not None and it >= max_batches:
                break
            images = images.cuda(opts.gpu).detach()
            captions = captions.cuda(opts.gpu).detach()
            img_emb, cap_emb = subspace.forward_emb(images, captions, lengths,
                                                    volatile=True)
            batch = images.size(0)
            # The DRIT encoder consumes embeddings reshaped to (batch, -1, 32, 32).
            image1, text1 = model.test_model2(img_emb.view(batch, -1, 32, 32),
                                              cap_emb.view(batch, -1, 32, 32))
            img2 = image1.view(batch, -1)
            cap2 = text1.view(batch, -1)
            if img_sub is None:
                img_sub = np.zeros((num_rows, img_emb.size(1)))
                cap_sub = np.zeros((num_rows, cap_emb.size(1)))
                img_enc = np.zeros((num_rows, img2.size(1)))
                cap_enc = np.zeros((num_rows, cap2.size(1)))
            img_sub[ids] = img_emb.data.cpu().numpy().copy()
            cap_sub[ids] = cap_emb.data.cpu().numpy().copy()
            img_enc[ids] = img2.data.cpu().numpy().copy()
            cap_enc[ids] = cap2.data.cpu().numpy().copy()
    if img_sub is None:
        raise ValueError('no batches were encoded for evaluation')
    return img_sub, cap_sub, img_enc, cap_enc


def _report_retrieval(tag, name, img_embs, cap_embs, measure):
    """Print i2t and t2i retrieval metrics for one embedding pair.

    Returns R@1 + R@5 + R@10 summed over both directions.
    """
    img_t = torch.from_numpy(img_embs)
    cap_t = torch.from_numpy(cap_embs)
    r1, r5, r10, medr, _meanr = i2t(img_t, cap_t, measure=measure)
    print('{}: {}: Med:{}, r1:{}, r5:{}, r10:{}'.format(
        tag, name, medr, r1, r5, r10))
    r1i, r5i, r10i, medri, _meanri = t2i(img_t, cap_t, measure=measure)
    print('{}: {}: Med:{}, r1:{}, r5:{}, r10:{}'.format(
        tag, name, medri, r1i, r5i, r10i))
    return r1 + r5 + r10 + r1i + r5i + r10i


def _evaluate(subspace, model, loader, opts, tag, num_rows, max_batches=None):
    """Encode *loader* and print subspace and encoder retrieval metrics.

    Returns the encoder rsum, which ``main`` uses for best-model selection.
    """
    img_sub, cap_sub, img_enc, cap_enc = _encode_split(
        subspace, model, loader, opts, num_rows, max_batches)
    _report_retrieval(tag, 'subspace', img_sub, cap_sub, opts.measure)
    return _report_retrieval(tag, 'encoder', img_enc, cap_enc, opts.measure)


def main():
    """Train the VSE subspace jointly with DRIT for image-caption retrieval.

    Stages, all configured by ``TrainOptions``:
      1. Pretraining for epochs ``ep0 .. opts.pre_iter - 1``. Each step calls
         ``model.pretrain_ae`` on the subspace embeddings, then steps
         ``subspace.pre_optimizer``. Gradients are clipped when
         ``opts.grad_clip > 0``.
      2. Joint training for epochs ``ep0 .. opts.n_ep - 1``. Each step runs
         ``opts.niters_gan_d`` discriminator updates (discriminators
         trainable), then ``opts.niters_gan_enc`` encoder updates
         (discriminators frozen), then steps ``subspace.optimizer``.
      3. Checkpoints are saved at the final epoch and every 10 epochs. Every
         ``opts.model_save_freq`` epochs, the first ``opts.val_iter`` test
         batches are evaluated; the best encoder rsum is kept as
         ``best_model.pth`` / ``subspace.pth``.
      4. After training, the whole test split is evaluated.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    vocab_file = os.path.join(opts.vocab_path, '%s_vocab.pkl' % opts.data_name)
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # opts.vocab_path at trusted files.
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)
    opts.vocab_size = len(vocab)
    torch.backends.cudnn.enabled = False

    # Load data loaders (the validation loader is currently unused).
    train_loader, _ = data.get_loaders(opts.data_name, vocab, opts.crop_size,
                                       opts.batch_size, opts.workers, opts)
    test_loader = data.get_test_loader('test', opts.data_name, vocab,
                                       opts.crop_size, opts.batch_size,
                                       opts.workers, opts)

    # model
    print('\n--- load subspace ---')
    subspace = model_2.VSE(opts)
    subspace.setgpu()
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:  # no previously saved model: start from scratch
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # Saver is not used below; it is kept in case its constructor has side
    # effects (e.g. creating output directories) — TODO confirm.
    saver = Saver(opts)  # noqa: F841

    # train
    print('\n--- train ---')
    score = 0.0
    subspace.train_start()

    # Stage 1: pretraining.
    for ep in range(ep0, opts.pre_iter):
        print('-----ep:{} --------'.format(ep))
        for it, (images, captions, lengths, ids) in enumerate(train_loader):
            if it >= opts.train_iter:
                break
            images = images.cuda(opts.gpu).detach()
            captions = captions.cuda(opts.gpu).detach()
            # The original author's note gives the embeddings as [b, 1024].
            img, cap = subspace.train_emb(images, captions, lengths, ids,
                                          pre=True)
            subspace.pre_optimizer.zero_grad()
            img = img.view(images.size(0), -1, 32, 32)
            cap = cap.view(images.size(0), -1, 32, 32)
            model.pretrain_ae(img, cap)
            if opts.grad_clip > 0:
                clip_grad_norm(subspace.params, opts.grad_clip)
            subspace.pre_optimizer.step()

    # Stage 2: joint training.
    for ep in range(ep0, opts.n_ep):
        subspace.train_start()
        adjust_learning_rate(opts, subspace.optimizer, ep)
        for it, (images, captions, lengths, ids) in enumerate(train_loader):
            if it >= opts.train_iter:
                break
            images = images.cuda(opts.gpu).detach()
            captions = captions.cuda(opts.gpu).detach()
            img, cap = subspace.train_emb(images, captions, lengths, ids)
            img = img.view(images.size(0), -1, 32, 32)
            cap = cap.view(images.size(0), -1, 32, 32)
            subspace.optimizer.zero_grad()

            discriminators = (model.disA, model.disB,
                              model.disA_attr, model.disB_attr)
            # Discriminators are trainable only during their own updates.
            _set_requires_grad(discriminators, True)
            for _ in range(opts.niters_gan_d):
                model.update_D(img, cap)
            _set_requires_grad(discriminators, False)
            for _ in range(opts.niters_gan_enc):
                # Original author's note: uses the new content loss.
                model.update_E(img, cap)
            subspace.optimizer.step()

            print('total_it: %d (ep %d, it %d), lr %09f' % (
                total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        if (ep + 1) % opts.n_ep == 0:
            _save_checkpoint(model, subspace, opts, 'final_model.pth',
                             'final_subspace.pth', ep, total_it)
        elif (ep + 1) % 10 == 0:
            _save_checkpoint(model, subspace, opts,
                             '%s_model.pth' % str(ep + 1),
                             '%s_subspace.pth' % str(ep + 1), ep, total_it)

        # Stage 3: periodic validation on the first opts.val_iter test batches.
        if (ep + 1) % opts.model_save_freq == 0:
            curr = _evaluate(subspace, model, test_loader, opts, 'test640',
                             num_rows=opts.val_iter * opts.batch_size,
                             max_batches=opts.val_iter)
            if curr > score:
                score = curr
                _save_checkpoint(model, subspace, opts, 'best_model.pth',
                                 'subspace.pth', ep, total_it)

    # Stage 4: evaluation on the whole test split.
    # NOTE(review): the original source had its indentation collapsed; this
    # evaluation is reconstructed as running once after the epoch loop —
    # confirm against the original.
    _evaluate(subspace, model, test_loader, opts, 'test5000',
              num_rows=len(test_loader.dataset))
def main():
    """Train DRIT on unpaired image domains A and B.

    Parses ``TrainOptions`` and builds a shuffled ``DataLoader`` over
    ``dataset_unpair_multi`` (when ``opts.multi_modal``) or ``dataset_unpair``.
    It then initializes or resumes a ``DRIT`` model and trains for epochs
    ``ep0 .. opts.n_ep - 1``. Training stops after the epoch in which
    ``total_it`` reaches ``max_it``.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    if opts.multi_modal:
        dataset = dataset_unpair_multi(opts)
    else:
        dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    num_batches = len(train_loader)
    reached_max_it = False
    for ep in range(ep0, opts.n_ep):
        for it, (images_a, images_b) in enumerate(train_loader):
            # Skip an incomplete trailing batch.
            if (images_a.size(0) != opts.batch_size
                    or images_b.size(0) != opts.batch_size):
                continue

            # input data
            images_a = images_a.cuda(opts.gpu).detach()
            images_b = images_b.cuda(opts.gpu).detach()

            # Only the content discriminator is updated, except on every
            # opts.d_iter-th iteration and the last two iterations of the
            # epoch. Those iterations update D and E/G instead.
            if (it + 1) % opts.d_iter != 0 and it < num_batches - 2:
                model.update_D_content(images_a, images_b)
                continue
            model.update_D(images_a, images_b)
            model.update_EG()

            # save to display file
            if not opts.no_display_img and not opts.multi_modal:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (
                total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            total_it += 1
            if total_it >= max_it:
                # write_model takes (ep, total_it, model) — see the
                # per-epoch call below; the old 2-argument call raised
                # TypeError here.
                saver.write_model(-1, total_it, model)
                reached_max_it = True
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        if not opts.multi_modal:
            saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

        # Previously `break` only left the inner loop, so training went on
        # past max_it; stop the whole run instead.
        if reached_max_it:
            break
def main():
    """Train DRIT on the unpaired two-domain dataset ``dataset_unpair``.

    Parses ``TrainOptions``, builds a shuffled ``DataLoader``, initializes or
    resumes a ``DRIT`` model, and trains for epochs ``ep0 .. opts.n_ep - 1``.
    After each epoch it optionally decays the learning rate, then calls
    ``saver.write_img`` and ``saver.write_model``. Training stops after the
    epoch in which ``total_it`` reaches ``max_it``.

    Setting the local ``debug_mode`` to True skips ``model.setgpu`` and the
    ``.cuda()`` transfers, so the model and batches stay on the CPU.
    """
    debug_mode = False

    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('\n--- load dataset ---')
    dataset = dataset_unpair(opts)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opts.batch_size,
                                               shuffle=True,
                                               num_workers=opts.nThreads)
    # Original author's note from inspecting dataset_unpair: images are first
    # resized to 256x256 and then randomly cropped to 216x216 patches
    # (center-cropped at test time).

    # model
    print('\n--- load model ---')
    model = DRIT(opts)
    if not debug_mode:
        model.setgpu(opts.gpu)
    if opts.resume is None:
        model.initialize()
        ep0 = -1
        total_it = 0
    else:
        ep0, total_it = model.resume(opts.resume)
    model.set_scheduler(opts, last_ep=ep0)
    ep0 += 1
    print('start the training at epoch %d' % (ep0))

    # saver for display and output
    saver = Saver(opts)

    # train
    print('\n--- train ---')
    max_it = 500000
    num_batches = len(train_loader)
    reached_max_it = False
    for ep in range(ep0, opts.n_ep):
        # Original author's note: images_a / images_b have shape
        # (2, 3, 216, 216) — presumably for batch_size=2 and a 216 crop.
        for it, (images_a, images_b) in enumerate(train_loader):
            # Skip an incomplete trailing batch (e.g. one or two leftover
            # samples) and move on to the next one.
            if (images_a.size(0) != opts.batch_size
                    or images_b.size(0) != opts.batch_size):
                continue

            # input data
            if not debug_mode:
                # detach(): the original author's guess is that it avoids
                # computing unneeded gradients and saves GPU memory.
                images_a = images_a.cuda(opts.gpu).detach()
                images_b = images_b.cuda(opts.gpu).detach()

            # Only the content discriminator is updated when (it + 1) is not
            # a multiple of opts.d_iter and the iteration is not one of the
            # last two. Otherwise D and E/G are updated. E.g. with d_iter=3,
            # two of every three iterations update only the content
            # discriminator.
            if (it + 1) % opts.d_iter != 0 and it < num_batches - 2:
                model.update_D_content(images_a, images_b)
                continue
            model.update_D(images_a, images_b)
            model.update_EG()

            # save to display file
            if not opts.no_display_img:
                saver.write_display(total_it, model)

            print('total_it: %d (ep %d, it %d), lr %08f' % (
                total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
            sys.stdout.flush()
            total_it += 1
            if total_it >= max_it:
                saver.write_img(-1, model)
                # write_model takes (ep, total_it, model) — see the
                # per-epoch call below; the old 2-argument call raised
                # TypeError here.
                saver.write_model(-1, total_it, model)
                reached_max_it = True
                break

        # decay learning rate
        if opts.n_ep_decay > -1:
            model.update_lr()

        # save result image
        saver.write_img(ep, model)

        # Save network weights
        saver.write_model(ep, total_it, model)

        # Previously `break` only left the inner loop, so training went on
        # past max_it; stop the whole run instead.
        if reached_max_it:
            break