def test_scan(self):
    """SCAN should expose 6 + 4 trainable variables via get_vars()."""
    dae = DAE()
    vae = VAE(dae)
    scan = SCAN(dae, vae)
    # Renamed from `vars` to avoid shadowing the builtin.
    scan_vars = scan.get_vars()
    # Check size of optimizing vars
    self.assertEqual(len(scan_vars), 6 + 4)
def main(argv):
    """Train the DAE -> VAE -> SCAN -> SCAN-recombinator pipeline.

    Each stage is gated by a ``flags.train_*`` switch; previously trained
    stages are restored from their checkpoints first so later stages can
    build on them.  NOTE(review): assumes `flags` is a module-level
    FLAGS/config object — confirm against the surrounding file.
    """
    data_manager = DataManager()
    data_manager.prepare()
    dae = DAE()
    vae = VAE(dae, beta=flags.vae_beta)
    scan = SCAN(dae, vae, beta=flags.scan_beta, lambd=flags.scan_lambda)
    scan_recomb = SCANRecombinator(dae, vae, scan)
    # One saver per sub-model so each stage checkpoints independently.
    dae_saver = CheckPointSaver(flags.checkpoint_dir, "dae", dae.get_vars())
    vae_saver = CheckPointSaver(flags.checkpoint_dir, "vae", vae.get_vars())
    scan_saver = CheckPointSaver(flags.checkpoint_dir, "scan", scan.get_vars())
    scan_recomb_saver = CheckPointSaver(flags.checkpoint_dir, "scan_recomb",
                                        scan_recomb.get_vars())
    sess = tf.Session()
    # Initialize variables
    init = tf.global_variables_initializer()
    sess.run(init)
    # For Tensorboard log
    summary_writer = tf.summary.FileWriter(flags.log_file, sess.graph)
    # Load from checkpoint
    dae_saver.load(sess)
    vae_saver.load(sess)
    scan_saver.load(sess)
    scan_recomb_saver.load(sess)
    # Train
    if flags.train_dae:
        train_dae(sess, dae, data_manager, dae_saver, summary_writer)
    if flags.train_vae:
        train_vae(sess, vae, data_manager, vae_saver, summary_writer)
        disentangle_check(sess, vae, data_manager)
    if flags.train_scan:
        train_scan(sess, scan, data_manager, scan_saver, summary_writer)
        sym2img_check(sess, scan, data_manager)
        img2sym_check(sess, scan, data_manager)
    if flags.train_scan_recomb:
        train_scan_recomb(sess, scan_recomb, data_manager, scan_recomb_saver,
                          summary_writer)
        recombination_check(sess, scan_recomb, data_manager)
    sess.close()
def load_model(model_path, device):
    """Load a SCAN checkpoint and its training options.

    Parameters
    ----------
    model_path : str
        Path to a ``torch.save``-d checkpoint dict with 'opt' and 'model'.
    device : torch.device or str
        Device to map the checkpoint tensors onto.

    Returns
    -------
    tuple
        ``(model, opt)`` — the restored SCAN model and its options.
    """
    # load model and options
    checkpoint = torch.load(model_path, map_location=device)
    opt = checkpoint['opt']
    # Older checkpoints predate `div_transform`; default it to False.
    # BUG FIX: use setdefault so a value already stored in the checkpoint
    # is no longer silently overwritten.
    d = vars(opt)
    d.setdefault('div_transform', False)
    # construct model
    model = SCAN(opt)
    # load model state
    model.load_state_dict(checkpoint['model'])
    return model, opt
def test_scan(self):
    """The 'scan' variable scope must hold exactly 6 + 4 trainables."""
    dae = DAE()
    vae = VAE(dae)
    scan = SCAN(dae, vae)
    trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "scan")
    # Check size of optimizing vars
    expected_count = 6 + 4
    self.assertEqual(len(trainable), expected_count)
def test_scan_recombinator(self):
    """SCANRecombinator should expose exactly 4 trainable variables."""
    dae = DAE()
    vae = VAE(dae)
    scan = SCAN(dae, vae)
    scan_recomb = SCANRecombinator(dae, vae, scan)
    # Renamed from `vars` to avoid shadowing the builtin.
    recomb_vars = scan_recomb.get_vars()
    # Check size of optimizing vars
    self.assertEqual(len(recomb_vars), 4)
def load_model(model_path, device):
    """Load a SCAN checkpoint and its options from ``model_path``.

    Returns ``(model, opt)`` — the restored model and the option
    namespace stored in the checkpoint.
    """
    # load model and options
    checkpoint = torch.load(model_path, map_location=device)
    opt = checkpoint['opt']
    # Removed a stale block of commented-out option back-fills
    # (layernorm/div_transform/net/txt_enc/diversity_loss) — dead code.
    # construct model
    model = SCAN(opt)
    # load model state
    model.load_state_dict(checkpoint['model'])
    return model, opt
def main(argv):
    """Train DAE, VAE and SCAN sequentially, checkpointing each stage.

    All three stages run unconditionally; previously saved weights are
    restored first so each run continues from its last checkpoint.
    """
    data_manager = DataManager()
    data_manager.prepare()
    dae = DAE()
    vae = VAE(dae)
    scan = SCAN(dae, vae)
    # One saver per sub-model so stages can be restored independently.
    dae_saver = CheckPointSaver(CHECKPOINT_DIR, "dae", dae.get_vars())
    vae_saver = CheckPointSaver(CHECKPOINT_DIR, "vae", vae.get_vars())
    scan_saver = CheckPointSaver(CHECKPOINT_DIR, "scan", scan.get_vars())
    sess = tf.Session()
    # Initialize variables
    init = tf.global_variables_initializer()
    sess.run(init)
    # For Tensorboard log
    summary_writer = tf.summary.FileWriter(LOG_FILE, sess.graph)
    # Load from checkpoint
    dae_saver.load(sess)
    vae_saver.load(sess)
    scan_saver.load(sess)
    # Train each stage, then run its qualitative check.
    train_dae(sess, dae, data_manager, dae_saver, summary_writer)
    train_vae(sess, vae, data_manager, vae_saver, summary_writer)
    disentangle_check(sess, vae, data_manager)
    train_scan(sess, scan, data_manager, scan_saver, summary_writer)
    sym2img_check(sess, scan, data_manager)
    img2sym_check(sess, scan, data_manager)
    sess.close()
def main():
    """Evaluate a trained SCAN model on the requested val/test split."""
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='./data/',
                        help='path to datasets')
    parser.add_argument('--model_path', default='./data/',
                        help='path to model')
    parser.add_argument('--split', default='test', help='val/test')
    # BUG FIX: default was the float `0.` while type=str; argparse does not
    # apply `type` to defaults, so CUDA_VISIBLE_DEVICES became "0.0".
    parser.add_argument('--gpuid', default='0', type=str, help='gpuid')
    parser.add_argument('--fold5', action='store_true', help='fold5')
    opts = parser.parse_args()

    device_id = opts.gpuid
    print("use GPU:", device_id)
    # Restrict visibility to the chosen GPU, then address it as device 0.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    device_id = 0
    torch.cuda.set_device(0)

    # load model and options
    checkpoint = torch.load(opts.model_path)
    opt = checkpoint['opt']
    opt.loss_verbose = False
    opt.split = opts.split
    opt.data_path = opts.data_path
    opt.fold5 = opts.fold5

    # load vocabulary used by the model
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # construct model
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = data.get_test_loader(opt.split, opt.data_name, vocab,
                                       opt.batch_size, opt.workers, opt)
    print(opt)
    print('Computing results...')
    evaluation.evalrank(model.module, data_loader, opt, split=opt.split,
                        fold5=opt.fold5)
def evalrank(input_string, img_feature, how_many, model_path, data_path=None,
             split='dev', fold5=False, gpu_num=None):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Returns the indices of the `how_many` images most similar to
    `input_string` (best first); empty list on the fold5 path.
    """
    # load model and options
    s_t = time.time()
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    print("%s seconds taken to load checkpoint" % (time.time() - s_t))
    if data_path is not None:
        opt.data_path = data_path

    # construct model
    model = SCAN(opt)
    # load model state
    model.load_state_dict(checkpoint['model'])

    # docker dir (local dir was '/home/ivy/hard2/scan_data/vocab')
    opt.vocab_path = '/scan/SCAN/data/vocab'

    # load vocabulary used by the model
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    print("Loading npy file")
    start_time = time.time()
    # Image features are passed in directly instead of np.load-ing them.
    img_embs = img_feature
    print("%s seconds takes to load npy file" % (time.time() - start_time))

    captions = []
    captions.append(str(input_string))
    # BUG FIX: Python-2 remnant `.decode('utf-8')` removed — py3 `str`
    # has no .decode(); lowercase the text directly before tokenizing.
    tokens = nltk.tokenize.word_tokenize(str(captions).lower())
    caption = []
    caption.append(vocab('<start>'))
    caption.extend([vocab(token) for token in tokens])
    caption.append(vocab('<end>'))

    # Repeat the single caption to fill one batch.
    target = []
    for batch in range(opt.batch_size):
        target.append(caption)
    target = torch.Tensor(target).long()

    print('Calculating results...')
    start_time = time.time()
    cap_embs, cap_len = encode_data(model, target, opt.batch_size)
    cap_lens = cap_len[0]
    print("%s seconds takes to calculate results" % (time.time() - start_time))
    print("Caption length with start and end index : ", cap_lens)
    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    # BUG FIX: `final_result` was only assigned on the non-fold5 path but
    # returned unconditionally; default it to an empty list.
    final_result = []
    if not fold5:
        img_embs = np.array(img_embs)
        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs, cap_embs, cap_lens, opt,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs, cap_embs, cap_lens, opt,
                                   shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)
        # Indices of the `how_many` most similar images, best first.
        top_n = np.argsort(sims, axis=0)[-(how_many):][::-1].flatten()
        final_result = list(top_n)
    # 5fold cross-validation, only for MSCOCO
    else:
        for i in range(10):
            if i < 9:
                img_embs_shard = img_embs[i * (img_embs.shape[0] // 10):
                                          (i + 1) * (img_embs.shape[0] // 10)]
            else:
                img_embs_shard = img_embs[i * (img_embs.shape[0] // 10):]
            cap_embs_shard = cap_embs
            cap_lens_shard = cap_lens
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs_shard, cap_embs_shard,
                                       cap_lens_shard, opt, shard_size=128)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs_shard, cap_embs_shard,
                                       cap_lens_shard, opt, shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)
            top_10 = np.argsort(sims, axis=0)[-10:][::-1].flatten()
            print("Top 10 list for iteration #%d : " % (i + 1) +
                  str(top_10 + 5000 * i))
        # Removed a large block of commented-out metric code (dead code).
    return final_result
def start_experiment(opt, seed):
    """Seed all RNGs, build model and data, and run the training loop.

    Returns the best validation rsum observed across all epochs.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print("Let's use", torch.cuda.device_count(), "GPUs!")
    print("Number threads:", torch.get_num_threads())

    # Vocabulary wrapper maps between token ids and words.
    vocab = deserialize_vocab("{}/{}/{}_vocab_{}.json".format(
        opt.vocab_path, opt.clothing, opt.data_name, opt.version))
    opt.vocab_size = len(vocab)

    # Data loaders for the train and validation splits.
    train_loader, val_loader = data_ken.get_loaders(opt.data_name, vocab,
                                                    opt.batch_size,
                                                    opt.workers, opt)

    # Build the model and persist the run configuration.
    model = SCAN(opt)
    save_hyperparameters(opt.logger_name, opt)

    best_rsum = 0
    start_epoch = 0

    # Optionally resume from a checkpoint.
    if opt.resume:
        if not os.path.isfile(opt.resume):
            print("=> no checkpoint found at '{}'".format(opt.resume))
        else:
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch'] + 1
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters carries the iteration counter across runs so logging
            # continues where the previous training stopped.
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)

    # Main training loop.
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # One epoch of training, then validation.
        train(opt, train_loader, model, epoch, val_loader)
        rsum = validate(opt, val_loader, model)

        # Track the best R@sum seen so far.
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)

        # Checkpoint only on a new best, or on the final epoch so that
        # training can be continued later.
        last_epoch = epoch == (opt.num_epochs - 1)
        if is_best or last_epoch:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'best_rsum': best_rsum,
                    'opt': opt,
                    'Eiters': model.Eiters,
                },
                is_best,
                last_epoch,
                filename='checkpoint_{}.pth.tar'.format(epoch),
                prefix=opt.model_name + '/')
    return best_rsum
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.
    """
    # Restore the checkpoint and its training options.
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    # Vocabulary the model was trained with.
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Rebuild the model and load its weights.
    model = SCAN(opt)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab, opt.batch_size,
                                  opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    def _similarity(im, caps, lens):
        # Similarity matrix for the configured cross-attention direction.
        t0 = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(im, caps, lens, opt, shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(im, caps, lens, opt, shard_size=128)
        else:
            raise NotImplementedError
        t1 = time.time()
        print("calculate similarity time:", t1 - t0)
        return sims

    if not fold5:
        # no cross-validation, full evaluation
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
        sims = _similarity(img_embs, cap_embs, cap_lens)

        r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
        ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:
        # 5fold cross-validation, only for MSCOCO
        results = []
        for i in range(5):
            img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
            sims = _similarity(img_embs_shard, cap_embs_shard, cap_lens_shard)

            r, rt0 = i2t(img_embs_shard, cap_embs_shard, cap_lens_shard, sims,
                         return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(img_embs_shard, cap_embs_shard, cap_lens_shard,
                           sims, return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
        print("rsum: %.1f" % (mean_metrics[10] * 6))
        print("Average i2t Recall: %.1f" % mean_metrics[11])
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_metrics[12])
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[5:10])

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
def main(args):
    """Visualize retrieval matches and attention for a trained SCAN run."""
    model_path = "{}/{}/seed1/checkpoint/{}".format(args.model_path, args.run,
                                                    args.checkpoint)

    # load model and options
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    opt = checkpoint['opt']

    # add because basic is not present in model
    d = vars(opt)
    d['basic'] = False

    run = args.run
    data_path = "{}/{}".format(args.data_path, args.data_name)
    nr_examples = args.nr_examples
    version = opt.version
    clothing = opt.clothing

    if opt.trans:
        plot_folder = "plots_trans"
    else:
        plot_folder = "plots_scan"

    plot_path = '{}/{}_{}'.format(plot_folder, version, run)
    caption_test_path = "{}/{}/data_captions_{}_test.txt".format(
        data_path, clothing, version)
    image_path = "{}".format(args.image_folder)
    vocab_path = "{}/{}".format(args.vocab_path, clothing)
    data_folder = "../data"

    if not os.path.exists(plot_path):
        os.makedirs(plot_path)

    # change image paths from lisa folders to local folders
    opt.data_path = data_folder
    opt.image_path = image_path
    opt.vocab_path = vocab_path
    print(opt)

    # construct model
    model = SCAN(opt)
    # load model state
    model.load_state_dict(checkpoint['model'])

    try:
        embs = torch.load("{}/embs/embs_{}_{}.pth.tar".format(
            plot_folder, run, version), map_location=('cpu'))
        print("loading embeddings")
        img_embs = embs["img_embs"]
        cap_embs = embs["cap_embs"]
        cap_lens = embs["cap_lens"]
        freqs = embs["freqs"]
    # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; narrowed to Exception.
    except Exception:
        print("Create embeddings")
        img_embs, cap_embs, cap_lens, freqs = get_embs(opt, model, run,
                                                       version, data_path,
                                                       plot_folder,
                                                       vocab_path=vocab_path)

    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    temp = torch.load("{}/ranks_{}_{}.pth.tar".format(plot_folder, run,
                                                      version),
                      map_location=('cpu'))
    rt = temp["rt"]
    rti = temp["rti"]
    attn = temp["attn"]
    t2i_switch = temp["t2i_switch"]

    r_i2t = calculate_r(rt[0], "i2t")
    r_t2i = calculate_r(rti[0], "t2i")
    top1_rt = rt[1]
    top1_rti = rti[1]

    if args.focus_subset:
        # Restrict metrics and examples to captions containing the word.
        indx = get_indx_subset(caption_test_path, args.word_asked)
        rs_i2t = calculate_r(rt[0][indx], "i2t")
        rs_t2i = calculate_r(rti[0][indx], "t2i")
        print_result_subset(rs_i2t, r_i2t, "i2t", args.word_asked)
        print_result_subset(rs_t2i, r_t2i, "t2i", args.word_asked)
        rnd_indx = get_random_indx(nr_examples, len(indx))
        rnd = [indx[i] for i in rnd_indx]
    else:
        rnd = get_random_indx(nr_examples, len(top1_rt))

    # dictionary to turn test_ids to data_ids
    test_id2data = {}
    # find the caption and image with every id in the test file
    # {caption_id : (image_id, caption)}
    with open(caption_test_path, newline='') as file:
        caption_reader = csv.reader(file, delimiter='\t')
        for i, line in enumerate(caption_reader):
            test_id2data[i] = (line[0], line[1])

    h5_images = get_h5_images(args.data_name, data_path)

    # get the matches
    matches_i2t = get_matches_i2t(top1_rt, test_id2data, nr_examples, rnd)
    matches_t2i = get_matches_t2i(top1_rti, test_id2data, nr_examples, rnd)

    # get id for file name
    unique_id = get_id(plot_path, "i2t", run)

    # plot image and caption together
    show_plots(matches_i2t, len(matches_i2t), "i2t", run, version, plot_path,
               args, clothing, h5_images, unique_id)
    show_plots(matches_t2i, len(matches_t2i), "t2i", run, version, plot_path,
               args, clothing, h5_images, unique_id)

    for i in range(len(rnd)):
        wanted_id = rnd[i]
        target_id = get_target_id(top1_rt, top1_rti, t2i_switch, wanted_id)
        attn = get_attn(img_embs, cap_embs, cap_lens, wanted_id, target_id,
                        opt, t2i_switch, freqs)
        if t2i_switch:
            words_caption = get_captions(test_id2data, wanted_id)
            image_segs = get_image_segs(target_id, test_id2data, args, opt,
                                        model, h5_images)
            match_t2i_viz(attn, wanted_id, target_id, test_id2data, run,
                          version, plot_path, clothing, words_caption,
                          image_segs)
        else:
            words_caption = get_captions(test_id2data, target_id)
            image_segs = get_image_segs(wanted_id, test_id2data, args, opt,
                                        model, h5_images)
            match_i2t_viz(attn, wanted_id, target_id, test_id2data, run,
                          version, plot_path, words_caption, image_segs)
def main():
    """Train a SCAN model on f30k features, optionally resuming."""
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/data',
        help='path to datasets')
    parser.add_argument('--data_name', default='f30k_precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument(
        '--vocab_path',
        default='/data3/zhangyf/cross_modal_retrieval/SCAN/vocab/',
        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs', default=20, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim', default=300, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--decoder_dim', default=512, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size', default=1024, type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip', default=2., type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers', default=1, type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate', default=.0002, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update', default=10, type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers', default=4, type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step', default=30, type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step', default=500, type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name', default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name', default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument(
        '--resume',
        default='/data3/zhangyf/cross_modal_retrieval/vsepp_next_train_12_31_f30k/run/coco_vse++_ft_128_f30k_next/model_best.pth.tar',
        type=str, metavar='PATH',
        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation', action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim', default=2048, type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm', action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm', action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument('--precomp_enc_type', default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--reset_train', action='store_true',
                        help='Ensure the training is always done in '
                        'train mode (Not recommended).')
    parser.add_argument('--finetune', action='store_true',
                        help='Fine-tune the image encoder.')
    parser.add_argument('--cnn_type', default='resnet152',
                        help="""The CNN used for image encoder (e.g.
                        vgg19, resnet152)""")
    parser.add_argument('--crop_size', default=224, type=int,
                        help='Size of an image crop as the CNN input.')
    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = pickle.load(
        open(os.path.join(opt.vocab_path, '%s_vocab.pkl' % opt.data_name),
             'rb'))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # BUG FIX: initialize before the resume branch so a restored best_rsum
    # is not clobbered, and keep start_epoch so training resumes at the
    # checkpointed epoch instead of epoch 0.
    best_rsum = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        # BUG FIX: the return value was assigned to a misspelled
        # `bset_rsum` and silently discarded.
        best_rsum = train(opt, train_loader, model, epoch, val_loader,
                          best_rsum)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
def main():
    """Train a SCAN model on precomputed features, optionally resuming."""
    # Hyper Parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='./data/',
                        help='path to datasets')
    parser.add_argument('--data_name', default='precomp',
                        help='{coco,f30k}_precomp')
    parser.add_argument('--vocab_path', default='./vocab/',
                        help='Path to saved vocabulary json files.')
    parser.add_argument('--margin', default=0.2, type=float,
                        help='Rank loss margin.')
    parser.add_argument('--num_epochs', default=30, type=int,
                        help='Number of training epochs.')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch.')
    parser.add_argument('--word_dim', default=300, type=int,
                        help='Dimensionality of the word embedding.')
    parser.add_argument('--embed_size', default=1024, type=int,
                        help='Dimensionality of the joint embedding.')
    parser.add_argument('--grad_clip', default=2., type=float,
                        help='Gradient clipping threshold.')
    parser.add_argument('--num_layers', default=1, type=int,
                        help='Number of GRU layers.')
    parser.add_argument('--learning_rate', default=.0002, type=float,
                        help='Initial learning rate.')
    parser.add_argument('--lr_update', default=15, type=int,
                        help='Number of epochs to update the learning rate.')
    parser.add_argument('--workers', default=10, type=int,
                        help='Number of data loader workers.')
    parser.add_argument('--log_step', default=10, type=int,
                        help='Number of steps to print and record the log.')
    parser.add_argument('--val_step', default=500, type=int,
                        help='Number of steps to run validation.')
    parser.add_argument('--logger_name', default='./runs/runX/log',
                        help='Path to save Tensorboard log.')
    parser.add_argument('--model_name', default='./runs/runX/checkpoint',
                        help='Path to save the model.')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--max_violation', action='store_true',
                        help='Use max instead of sum in the rank loss.')
    parser.add_argument('--img_dim', default=2048, type=int,
                        help='Dimensionality of the image embedding.')
    parser.add_argument('--no_imgnorm', action='store_true',
                        help='Do not normalize the image embeddings.')
    parser.add_argument('--no_txtnorm', action='store_true',
                        help='Do not normalize the text embeddings.')
    parser.add_argument(
        '--raw_feature_norm', default="clipped_l2norm",
        help='clipped_l2norm|l2norm|clipped_l1norm|l1norm|no_norm|softmax')
    parser.add_argument('--agg_func', default="LogSumExp",
                        help='LogSumExp|Mean|Max|Sum')
    parser.add_argument('--cross_attn', default="t2i", help='t2i|i2t')
    parser.add_argument('--precomp_enc_type', default="basic",
                        help='basic|weight_norm')
    parser.add_argument('--bi_gru', action='store_true',
                        help='Use bidirectional GRU.')
    parser.add_argument('--lambda_lse', default=6., type=float,
                        help='LogSumExp temp.')
    parser.add_argument('--lambda_softmax', default=9., type=float,
                        help='Attention softmax temperature.')
    opt = parser.parse_args()
    print(opt)

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    tb_logger.configure(opt.logger_name, flush_secs=5)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(opt.data_name, vocab,
                                                opt.batch_size, opt.workers,
                                                opt)

    # Construct the model
    model = SCAN(opt)

    # BUG FIX: initialize before the resume branch so a restored best_rsum
    # is not reset to 0, and honor start_epoch so training resumes at the
    # checkpointed epoch instead of epoch 0.
    best_rsum = 0
    start_epoch = 0

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            # Eiters is used to show logs as the continuation of another
            # training
            model.Eiters = checkpoint['Eiters']
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})".format(
                opt.resume, start_epoch, best_rsum))
            validate(opt, val_loader, model)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        print(opt.logger_name)
        print(opt.model_name)

        adjust_learning_rate(opt, model.optimizer, epoch)

        # train for one epoch
        train(opt, train_loader, model, epoch, val_loader)

        # evaluate on validation set
        rsum = validate(opt, val_loader, model)

        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        if not os.path.exists(opt.model_name):
            os.mkdir(opt.model_name)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'best_rsum': best_rsum,
                'opt': opt,
                'Eiters': model.Eiters,
            },
            is_best,
            filename='checkpoint_{}.pth.tar'.format(epoch),
            prefix=opt.model_name + '/')
vae_test_generator=data.DataLoader(test_set,**config.generator_params) train.train_bvae(BVAE_net,optim_bvae,vae_training_generator,vae_test_generator,config.BVAE_CHECKPOINT,config.BVAE_TRAIN_EPOCH,writer,config.BVAE_LOG) else: utils.load_model(config.BVAE_LOAD_PATH,BVAE_net,optim_bvae) """for batch_id,batch in enumerate(0,vae_test_generator): utils.visualize_recon(BVAE_net,DAE_net,channel_mean,channel_std,config.VIS_RECON_PATH) utils.latent_traversal(BVAE_net,DAE_net,channel_mean,channel_std,config.VIS_LATENT_TRAVERSAL) """ SCAN_net=SCAN(51,100,32,1,10,BVAE_net) SCAN_net.cuda() optim_scan = torch.optim.Adam(list(SCAN_net.encoder.parameters())+list(SCAN_net.decoder.parameters()),lr=1e-4) oh_train_set=dataloader.SCANdata(one_hots_train,train_data) oh_training_generator=data.DataLoader(oh_train_set,**config.scan_generator_params) oh_test_set=dataloader.SCANdata(one_hots_test,test_data) oh_test_generator=data.DataLoader(oh_test_set,**config.scan_generator_params) train.train_scan(SCAN_net,optim_scan,oh_training_generator,oh_test_generator,config.SCAN_CHECKPOINT,10,writer,"output_file_path") recomb_train_set=dataloader.RECOMBdata(data_set,perm,[15,15,15,2],20000)
print("Epoch:", '%04d' % (epoch + 1), "reconstr=", "{:.3f}".format(average_reconstr_loss), "latent0=", "{:.3f}".format(average_latent_loss0), "latent1=", "{:.3f}".format(average_latent_loss1)) if (epoch % save_epoch == 0) or (epoch == training_epochs - 1): torch.save(scan.state_dict(), '{}/scan_epoch_{}.pth'.format(exp, epoch)) data_manager = DataManager() data_manager.prepare() dae = DAE() vae = VAE() scan = SCAN() if use_cuda: dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth')) vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth')) scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth')) dae, vae, scan = dae.cuda(), vae.cuda(), scan.cuda() else: dae.load_state_dict( torch.load('save/dae/dae_epoch_2999.pth', map_location=lambda storage, loc: storage)) vae.load_state_dict( torch.load('save/vae/vae_epoch_2999.pth', map_location=lambda storage, loc: storage)) scan.load_state_dict( torch.load(exp + '/' + opt.load, map_location=lambda storage, loc: storage))
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """Evaluate a trained SCAN checkpoint on the given dataset split.

    Loads the checkpoint together with its saved training options, rebuilds
    the model and vocabulary, encodes images/captions, computes
    cross-attention similarities and prints recall metrics.  If `fold5` is
    True, evaluation is done as 5-fold cross-validation over 5x5000 items
    (MSCOCO-style layout); otherwise the full split is evaluated at once.

    Args:
        model_path: path to a torch checkpoint containing 'opt' and 'model'.
        data_path: optional override for opt.data_path.
        split: dataset split name passed to get_test_loader (e.g. 'dev').
        fold5: enable 5-fold cross-validation evaluation.
    """
    # Restore both the weights and the training-time options from the checkpoint.
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path
    # Vocabulary must match the one used at training time.
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)
    # Pre-computed caption embedding matrix handed to the SCAN constructor.
    captions_w = np.load(opt.caption_np + 'caption_np.npy')
    captions_w = torch.from_numpy(captions_w)
    captions_w = captions_w.cuda()
    # construct model
    model = SCAN(opt, captions_w)
    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens = encode_data(model, data_loader)
    # Images appear 5x (one copy per caption), hence the division by 5.
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # Keep one image embedding per group of 5 duplicated entries.
        img_embs = np.array([img_embs[i] for i in range(0, len(img_embs), 5)])
        start = time.time()
        if opt.cross_attn == 't2i':
            sims = shard_xattn_t2i(img_embs, cap_embs, cap_lens, opt,
                                   shard_size=128)
        elif opt.cross_attn == 'i2t':
            sims = shard_xattn_i2t(img_embs, cap_embs, cap_lens, opt,
                                   shard_size=128)
        elif opt.cross_attn == 'all':
            sims, label = shard_xattn_all(model, img_embs, cap_embs, cap_lens,
                                          opt, shard_size=128)
        else:
            raise NotImplementedError
        end = time.time()
        print("calculate similarity time:", end - start)
        np.save('sim_stage1', sims)
        # NOTE(review): `label` is only assigned in the 'all' branch above; the
        # 't2i'/'i2t' branches would raise NameError on the next line — confirm
        # whether those branches are still supported in this variant.
        r, rt = i2t(label, img_embs, cap_embs, cap_lens, sims, return_ranks=True)
        ri, rti = t2i(label, img_embs, cap_embs, cap_lens, sims, return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:
        # 5-fold cross-validation: evaluate each 5000-item shard separately.
        results = []
        for i in range(5):
            # Stride 5: one image embedding per group of 5 duplicates.
            img_embs_shard = img_embs[i * 5000:(i + 1) * 5000:5]
            cap_embs_shard = cap_embs[i * 5000:(i + 1) * 5000]
            cap_lens_shard = cap_lens[i * 5000:(i + 1) * 5000]
            start = time.time()
            if opt.cross_attn == 't2i':
                sims = shard_xattn_t2i(img_embs_shard, cap_embs_shard,
                                       cap_lens_shard, opt, shard_size=128)
            elif opt.cross_attn == 'i2t':
                sims = shard_xattn_i2t(img_embs_shard, cap_embs_shard,
                                       cap_lens_shard, opt, shard_size=128)
            else:
                raise NotImplementedError
            end = time.time()
            print("calculate similarity time:", end - start)
            r, rt0 = i2t(img_embs_shard, cap_embs_shard, cap_lens_shard, sims,
                         return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(img_embs_shard, cap_embs_shard, cap_lens_shard, sims,
                           return_ranks=True)
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
            # Keep the ranks of the first fold for the report saved below.
            if i == 0:
                rt, rti = rt0, rti0
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(np.array(results).mean(axis=0).flatten())
        print("rsum: %.1f" % (mean_metrics[10] * 6))
        print("Average i2t Recall: %.1f" % mean_metrics[11])
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_metrics[12])
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[5:10])

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
def create_model(opt, ema=False):
    """Construct and return a SCAN model from the parsed options.

    Args:
        opt: option namespace forwarded verbatim to the SCAN constructor.
        ema: presumably selects the exponential-moving-average copy of the
             model — TODO confirm against the SCAN constructor.

    Returns:
        The freshly constructed SCAN instance.
    """
    return SCAN(opt, ema)
def train(self):
    """Train the SCAN model with a triplet margin loss and hard-negative mining.

    Loads pre-trained weights, then for each epoch: mines hard negatives per
    batch, optimizes the margin loss over (positive, negative-caption,
    negative-image) similarity triples, optionally decays the learning rate
    every ``params.step_size`` epochs, checkpoints the weights, and every
    ``params.validate_every`` epochs computes recall@k on the validation set,
    saving the weights whenever R@1 improves.  A KeyboardInterrupt saves an
    interrupt checkpoint instead of losing progress.

    Reads (from self): params (hyper-parameters), data_loader, evaluator.
    Side effects: writes checkpoint files under ``self.params.model_dir``.
    """
    model = SCAN(self.params)
    model.apply(init_xavier)
    # Warm-start from a previously saved checkpoint.
    model.load_state_dict(torch.load('models/model_weights_5.t7'))
    loss_function = MarginLoss(self.params.margin)
    if torch.cuda.is_available():
        model = model.cuda()
        loss_function = loss_function.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=self.params.lr,
                                 weight_decay=self.params.wdecay)
    try:
        prev_best = 0
        for epoch in range(self.params.num_epochs):
            iters = 1
            losses = []
            start_time = timer()
            num_of_mini_batches = len(
                self.data_loader.train_ids) // self.params.batch_size
            for (caption, mask, image, neg_cap, neg_mask, neg_image) in tqdm(
                    self.data_loader.training_data_loader):
                # Sample according to hard negative mining
                caption, mask, image, neg_cap, neg_mask, neg_image = \
                    self.data_loader.hard_negative_mining(
                        model, caption, mask, image, neg_cap, neg_mask,
                        neg_image)
                model.train()
                optimizer.zero_grad()
                # Forward pass: positive pair plus the two mined negatives.
                similarity = model(to_variable(caption), to_variable(mask),
                                   to_variable(image), False)
                similarity_neg_1 = model(to_variable(neg_cap),
                                         to_variable(neg_mask),
                                         to_variable(image), False)
                similarity_neg_2 = model(to_variable(caption),
                                         to_variable(mask),
                                         to_variable(neg_image), False)
                # Compute the loss, gradients, and update the parameters.
                loss = loss_function(similarity, similarity_neg_1,
                                     similarity_neg_2)
                loss.backward()
                losses.append(loss.data.cpu().numpy())
                if self.params.clip_value > 0:
                    # clip_grad_norm_ is the non-deprecated in-place variant
                    # of the clip_grad_norm the original called.
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.params.clip_value)
                optimizer.step()
                iters += 1

            # BUG FIX: the original condition `epoch + 1 % step_size == 0`
            # parses as `(epoch + (1 % step_size)) == 0`, which (for
            # step_size > 1) is never true for epoch >= 0, so the learning
            # rate was never decayed.  Parenthesizing restores the intent.
            if (epoch + 1) % self.params.step_size == 0:
                optim_state = optimizer.state_dict()
                optim_state['param_groups'][0]['lr'] = (
                    optim_state['param_groups'][0]['lr'] / self.params.gamma)
                optimizer.load_state_dict(optim_state)

            torch.save(
                model.state_dict(),
                self.params.model_dir +
                '/model_weights_{}.t7'.format(epoch + 1))

            # Calculate r@k after each epoch
            if (epoch + 1) % self.params.validate_every == 0:
                r_at_1, r_at_5, r_at_10 = self.evaluator.recall(
                    model, is_test=False)
                # np.mean(...).item() replaces np.asscalar, which was removed
                # in NumPy 1.23; the numeric result is identical.
                print(
                    "Epoch {} : Training Loss: {:.5f}, R@1 : {}, R@5 : {}, R@10 : {}, Time elapsed {:.2f} mins"
                    .format(epoch + 1, np.mean(losses).item(), r_at_1,
                            r_at_5, r_at_10, (timer() - start_time) / 60))
                if r_at_1 > prev_best:
                    print("Recall at 1 increased....saving weights !!")
                    prev_best = r_at_1
                    # NOTE(review): unlike the other save paths there is no
                    # '/' before the filename here — confirm model_dir ends
                    # with a separator, otherwise this writes to a sibling
                    # path.
                    torch.save(
                        model.state_dict(),
                        self.params.model_dir +
                        'best_model_weights_{}.t7'.format(epoch + 1))
            else:
                print("Epoch {} : Training Loss: {:.5f}".format(
                    epoch + 1, np.mean(losses).item()))
    except KeyboardInterrupt:
        print("Interrupted.. saving model !!!")
        torch.save(model.state_dict(),
                   self.params.model_dir + '/model_weights_interrupt.t7')
def evalrank(model_path, run, data_path=None, split='dev', fold5=False,
             vocab_path="../vocab/", change=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: path to a torch checkpoint containing 'opt' and 'model'.
        run: run identifier used only to name the saved ranks file.
        data_path: optional override for opt.data_path.
        split: dataset split name passed to get_test_loader.
        fold5: declared but not used in this variant.
        vocab_path: directory prefix for the serialized vocabulary JSON.
        change: when True, switch opt.clothing to "dresses" before loading
            the evaluation data.

    Returns:
        (rt, rti, attn, r, ri): i2t ranks, t2i ranks, attention maps, and the
        i2t / t2i metric tuples.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)

    # add because div_transform is not present in model
    # d = vars(opt)
    # d['tanh'] = True

    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    vocab = deserialize_vocab("{}{}/{}_vocab_{}.json".format(
        vocab_path, opt.clothing, opt.data_name, opt.version))
    opt.vocab_size = len(vocab)
    print(opt.vocab_size)

    # construct model
    model = SCAN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    # Optionally switch the clothing category before loading eval data.
    if change:
        opt.clothing = "dresses"

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs, cap_lens, freqs = encode_data(model, data_loader)
    print('Images: %d, Captions: %d' % (img_embs.shape[0], cap_embs.shape[0]))

    # Records which attention direction produced the saved attention maps.
    t2i_switch = True
    if opt.cross_attn == 't2i':
        sims, attn = shard_xattn_t2i(img_embs, cap_embs, cap_lens, freqs, opt,
                                     shard_size=128)
    elif opt.cross_attn == 'i2t':
        sims, attn = shard_xattn_i2t(img_embs, cap_embs, cap_lens, freqs, opt,
                                     shard_size=128)
        t2i_switch = False
    else:
        raise NotImplementedError

    # r = (r1, r2, r5, medr, meanr), rt= (ranks, top1)
    r, rt = i2t(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ri, rti = t2i(img_embs, cap_embs, cap_lens, sims, return_ranks=True)
    ar = (r[0] + r[1] + r[2]) / 3
    ari = (ri[0] + ri[1] + ri[2]) / 3
    rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
    print("rsum: %.1f" % rsum)
    print("Average i2t Recall: %.1f" % ar)
    print("Image to text: %.1f %.1f %.1f %.1f %.1f %.1f %.1f" % r)
    print("Average t2i Recall: %.1f" % ari)
    print("Text to image: %.1f %.1f %.1f %.1f %.1f %.1f %.1f" % ri)

    # Persist ranks and attention maps for later plotting.
    if opt.trans:
        save_dir = "plots_trans"
    else:
        save_dir = "plots_scan"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    torch.save({
        'rt': rt,
        'rti': rti,
        "attn": attn,
        "t2i_switch": t2i_switch
    }, '{}/ranks_{}_{}.pth.tar'.format(save_dir, run, opt.version))

    return rt, rti, attn, r, ri
def evalrank(model_path, data_path=None, split='dev', fold5=False):
    """
    Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold
    cross-validation is done (only for MSCOCO). Otherwise, the full data is
    used for evaluation.

    Args:
        model_path: path to a torch checkpoint containing 'opt' and 'model'.
        data_path: optional override for opt.data_path.
        split: dataset split name passed to get_test_loader (e.g. 'dev').
        fold5: enable 5-fold cross-validation evaluation.
    """
    # load model and options
    checkpoint = torch.load(model_path)
    opt = checkpoint['opt']
    print(opt)
    if data_path is not None:
        opt.data_path = data_path

    # load vocabulary used by the model
    with open(os.path.join(opt.vocab_path,
                           '%s_vocab.pkl' % opt.data_name), 'rb') as f:
        vocab = pickle.load(f)
    opt.vocab_size = len(vocab)

    # construct model
    model = SCAN(opt)

    # load model state
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = get_test_loader(split, opt.data_name, vocab,
                                  opt.batch_size, opt.workers, opt)

    print('Computing results...')
    img_embs, cap_embs = encode_data(model, data_loader)
    # Images appear 5x (one copy per caption), hence the division by 5.
    print('Images: %d, Captions: %d' %
          (img_embs.shape[0] / 5, cap_embs.shape[0]))

    if not fold5:
        # no cross-validation, full evaluation
        r, rt = i2t(img_embs, cap_embs, measure=opt.measure,
                    return_ranks=True)
        ri, rti = t2i(img_embs, cap_embs, measure=opt.measure,
                      return_ranks=True)
        ar = (r[0] + r[1] + r[2]) / 3
        ari = (ri[0] + ri[1] + ri[2]) / 3
        rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
        print("rsum: %.1f" % rsum)
        print("Average i2t Recall: %.1f" % ar)
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % r)
        print("Average t2i Recall: %.1f" % ari)
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % ri)
    else:
        # 5fold cross-validation, only for MSCOCO
        results = []
        for i in range(5):
            r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000],
                         cap_embs[i * 5000:(i + 1) * 5000],
                         measure=opt.measure,
                         return_ranks=True)
            print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f" % r)
            ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000],
                           cap_embs[i * 5000:(i + 1) * 5000],
                           measure=opt.measure,
                           return_ranks=True)
            # Keep the ranks of the first fold for the report saved below.
            if i == 0:
                rt, rti = rt0, rti0
            print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f" % ri)
            ar = (r[0] + r[1] + r[2]) / 3
            ari = (ri[0] + ri[1] + ri[2]) / 3
            rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2]
            print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari))
            results += [list(r) + list(ri) + [ar, ari, rsum]]

        print("-----------------------------------")
        print("Mean metrics: ")
        mean_metrics = tuple(numpy.array(results).mean(axis=0).flatten())
        print("rsum: %.1f" % (mean_metrics[10] * 6))
        print("Average i2t Recall: %.1f" % mean_metrics[11])
        print("Image to text: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[:5])
        print("Average t2i Recall: %.1f" % mean_metrics[12])
        print("Text to image: %.1f %.1f %.1f %.1f %.1f" % mean_metrics[5:10])

    torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar')
def main():
    """Train a SCAN model end to end.

    Parses options, pins the requested GPU(s), loads the vocabulary and data
    loaders, builds the model under nn.DataParallel, optionally resumes from
    a checkpoint, then runs the epoch loop: train, periodically validate via
    evalrank, and checkpoint (tracking the best rsum).
    """
    # Hyper Parameters
    opt = opts.parse_opt()

    device_id = opt.gpuid
    device_count = len(str(device_id).split(","))
    #assert device_count == 1 or device_count == 2
    print("use GPU:", device_id, "GPUs_count", device_count, flush=True)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    # After CUDA_VISIBLE_DEVICES masking, visible devices renumber from 0.
    device_id = 0
    torch.cuda.set_device(0)

    # Load Vocabulary Wrapper
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Load data loaders
    train_loader, val_loader = data.get_loaders(
        opt.data_name, vocab, opt.batch_size, opt.workers, opt)

    # Construct the model
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)

    # Loss and Optimizer
    criterion = ContrastiveLoss(opt=opt, margin=opt.margin,
                                max_violation=opt.max_violation)
    mse_criterion = nn.MSELoss(reduction="batchmean")
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)

    # optionally resume from a checkpoint
    if not os.path.exists(opt.model_name):
        os.makedirs(opt.model_name)
    start_epoch = 0
    best_rsum = 0
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            start_epoch = checkpoint['epoch']
            best_rsum = checkpoint['best_rsum']
            model.load_state_dict(checkpoint['model'])
            print("=> loaded checkpoint '{}' (epoch {}, best_rsum {})"
                  .format(opt.resume, start_epoch, best_rsum))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))
        # Sanity-check the resumed model on the validation set.
        # NOTE(review): indentation reconstructed — this call is placed under
        # `if opt.resume:`; confirm against the original layout.
        evalrank(model.module, val_loader, opt)

    print(opt, flush=True)

    # Train the Model
    for epoch in range(start_epoch, opt.num_epochs):
        message = "epoch: %d, model name: %s\n" % (epoch, opt.model_name)
        log_file = os.path.join(opt.logger_name, "performance.log")
        logging_func(log_file, message)
        print("model name: ", opt.model_name, flush=True)

        adjust_learning_rate(opt, optimizer, epoch)

        run_time = 0
        for i, (images, captions, lengths, masks, ids, _) in enumerate(
                train_loader):
            start_time = time.time()
            model.train()
            optimizer.zero_grad()
            # Replicate the image batch so DataParallel can split it evenly.
            if device_count != 1:
                images = images.repeat(device_count, 1, 1)
            score = model(images, captions, lengths, masks, ids)
            loss = criterion(score)
            loss.backward()
            if opt.grad_clip > 0:
                clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()

            run_time += time.time() - start_time
            # validate at every val_step
            if i % 100 == 0:
                log = "epoch: %d; batch: %d/%d; loss: %.4f; time: %.4f" % (
                    epoch, i, len(train_loader), loss.data.item(),
                    run_time / 100)
                print(log, flush=True)
                run_time = 0
            if (i + 1) % opt.val_step == 0:
                evalrank(model.module, val_loader, opt)

        print("-------- performance at epoch: %d --------" % (epoch))
        # evaluate on validation set
        rsum = evalrank(model.module, val_loader, opt)
        #rsum = -100
        filename = 'model_' + str(epoch) + '.pth.tar'
        # remember best R@ sum and save checkpoint
        is_best = rsum > best_rsum
        best_rsum = max(rsum, best_rsum)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': model.state_dict(),
            'best_rsum': best_rsum,
            'opt': opt,
        }, is_best, filename=filename, prefix=opt.model_name + '/')
if epoch % display_epoch == 0: print("Epoch:", '%04d' % (epoch + 1), "loss=", "{}".format(average_loss)) if (epoch % save_epoch == 0) or (epoch == training_epochs - 1): torch.save(recomb.state_dict(), '{}/recomb_epoch_{}.pth'.format(exp, epoch)) data_manager = DataManager() data_manager.prepare() dae = DAE() vae = VAE() scan = SCAN() recomb = Recombinator() if use_cuda: dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth')) vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth')) scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth')) dae, vae, scan, recomb = dae.cuda(), vae.cuda(), scan.cuda(), recomb.cuda() else: dae.load_state_dict( torch.load('save/dae/dae_epoch_2999.pth', map_location=lambda storage, loc: storage)) vae.load_state_dict( torch.load('save/vae/vae_epoch_2999.pth', map_location=lambda storage, loc: storage)) scan.load_state_dict(