def get_latent(args):
	device = torch.device(args.gpu)
	print("Loading embedding model...")
	with open(os.path.join(CONFIG.DATASET_PATH, args.target_dataset, 'word_embedding.p'), "rb") as f:
		text_embedding_model = cPickle.load(f)
	with open(os.path.join(CONFIG.DATASET_PATH, args.target_dataset, 'word_idx.json'), "r", encoding='utf-8') as f:
		word_idx = json.load(f)
	print("Loading embedding model completed")
	print("Loading dataset...")
	full_dataset = load_fullmultimodal_data(args, CONFIG, word2idx=word_idx[1])
	print("Loading dataset completed")
	full_loader = DataLoader(full_dataset, batch_size=args.batch_size, shuffle=False)
	
	# t1 = max_sentence_len + 2 * (args.filter_shape - 1)
	t1 = CONFIG.MAX_SENTENCE_LEN
	t2 = int(math.floor((t1 - args.filter_shape) / 2) + 1) # "2" means stride size
	t3 = int(math.floor((t2 - args.filter_shape) / 2) + 1)
	args.t3 = t3
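	# t2 and t3 are the sequence lengths after two stride-2 convolutions without padding
	# (output length = floor((L_in - filter_shape) / 2) + 1). With hypothetical values
	# MAX_SENTENCE_LEN = 60 and filter_shape = 5:
	#   t2 = floor((60 - 5) / 2) + 1 = 28
	#   t3 = floor((28 - 5) / 2) + 1 = 12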

	text_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(text_embedding_model))
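	# Note: nn.Embedding.from_pretrained freezes the embedding weights by default (freeze=True).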
	text_encoder = text_model.ConvolutionEncoder(text_embedding, t3, args.filter_size, args.filter_shape, args.encode_latent)
	imgseq_encoder = imgseq_model.RNNEncoder(args.image_embedding_dim, args.num_layer, args.encode_latent, bidirectional=True)
	multimodal_encoder = multimodal_model.MultimodalEncoder(text_encoder, imgseq_encoder, args.latent_size, args.normalize, args.add_latent)
	checkpoint = torch.load(os.path.join(CONFIG.CHECKPOINT_PATH, args.checkpoint), map_location=lambda storage, loc: storage)
	multimodal_encoder.load_state_dict(checkpoint['multimodal_encoder'])
	multimodal_encoder.to(device)
	multimodal_encoder.eval() 


	csv_name = 'latent_' + args.target_dataset
	if args.normalize:
		csv_name = csv_name + "_normalize"
	if args.add_latent:
		csv_name = csv_name + "_add_latent"
	if args.no_decode:
		csv_name = csv_name + "_no_decode"
	csv_name = csv_name + '.csv'
	#f_csv = open(os.path.join(CONFIG.CSV_PATH, 'latent_' + args.target_dataset + '.csv'), 'w', encoding='utf-8-sig')
	#wr = csv.writer(f_csv)
	short_code_list = []
	row_list = []
	for text_batch, imgseq_batch, short_code in tqdm(full_loader):
		torch.cuda.empty_cache()
		with torch.no_grad():
			text_feature = Variable(text_batch).to(device)
			imgseq_feature = Variable(imgseq_batch).to(device)
			# Run the forward pass inside no_grad as well, so no autograd graph is built during inference.
			h = multimodal_encoder(text_feature, imgseq_feature)

		for _short_code, _h in zip(short_code, h):
			short_code_list.append(_short_code)
			row_list.append(_h.detach().cpu().numpy().tolist())
			# row = [_short_code] + _h.detach().cpu().numpy().tolist()
			# wr.writerow(row)
		del text_feature, imgseq_feature
	#f_csv.close()
	result_df = pd.DataFrame(data=row_list, index=short_code_list, columns=list(range(args.latent_size)))
	result_df.index.name = "short_code"
	result_df.sort_index(inplace=True)
	result_df.to_csv(os.path.join(CONFIG.CSV_PATH, csv_name), encoding='utf-8-sig')
	print("Finish!!!")
def get_latent(args):
    device = torch.device(args.gpu)
    print("Loading embedding model...")
    image_embedding_model = models.__dict__[args.arch](pretrained=True)
    image_embedding_dim = image_embedding_model.fc.in_features
    args.image_embedding_dim = image_embedding_dim
    model_name = 'FASTTEXT_' + args.target_dataset + '.model'
    text_embedding_model = FastTextKeyedVectors.load(
        os.path.join(CONFIG.EMBEDDING_PATH, model_name))
    text_embedding_dim = text_embedding_model.vector_size
    args.text_embedding_dim = text_embedding_dim
    print("Building index...")
    indexer = AnnoyIndexer(text_embedding_model, 10)
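    # The Annoy index is not queried in this example; with gensim it could serve
    # approximate nearest-neighbor lookups, e.g. (illustrative only):
    #   text_embedding_model.most_similar("some_word", topn=5, indexer=indexer)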
    print("Loading embedding model completed")
    print("Loading dataset...")
    full_dataset = load_full_data(args,
                                  CONFIG,
                                  text_embedding_model,
                                  total=True)
    print("Loading dataset completed")
    full_loader = DataLoader(full_dataset,
                             batch_size=args.batch_size,
                             shuffle=False)

    # t1 = max_sentence_len + 2 * (args.filter_shape - 1)
    t1 = CONFIG.MAX_SENTENCE_LEN
    t2 = int(math.floor(
        (t1 - args.filter_shape) / 2) + 1)  # "2" means stride size
    t3 = int(math.floor((t2 - args.filter_shape) / 2) + 1)
    args.t3 = t3

    text_encoder = text_model.ConvolutionEncoder(text_embedding_dim, t3,
                                                 args.filter_size,
                                                 args.filter_shape,
                                                 args.latent_size)
    text_decoder = text_model.DeconvolutionDecoder(text_embedding_dim, t3,
                                                   args.filter_size,
                                                   args.filter_shape,
                                                   args.latent_size)
    imgseq_encoder = imgseq_model.RNNEncoder(image_embedding_dim,
                                             args.num_layer,
                                             args.latent_size,
                                             bidirectional=True)
    imgseq_decoder = imgseq_model.RNNDecoder(image_embedding_dim,
                                             args.num_layer,
                                             args.latent_size,
                                             bidirectional=True)
    checkpoint = torch.load(os.path.join(CONFIG.CHECKPOINT_PATH,
                                         args.checkpoint),
                            map_location=lambda storage, loc: storage)
    multimodal_encoder = multimodal_model.MultimodalEncoder(
        text_encoder, imgseq_encoder, args.latent_size)
    multimodal_encoder.load_state_dict(checkpoint['multimodal_encoder'])
    multimodal_encoder.to(device)
    multimodal_encoder.eval()

    f_csv = open(os.path.join(CONFIG.CSV_PATH, 'latent_features.csv'),
                 'w',
                 encoding='utf-8')
    wr = csv.writer(f_csv)
    for steps, (text_batch, imgseq_batch,
                short_code) in enumerate(full_loader):
        torch.cuda.empty_cache()
        with torch.no_grad():
            text_feature = Variable(text_batch).to(device)
            imgseq_feature = Variable(imgseq_batch).to(device)
            # Forward pass inside no_grad: inference only, no autograd graph.
            h = multimodal_encoder(text_feature, imgseq_feature)
        # Write one CSV row per sample; short_code and h hold a whole batch, so writing
        # them as a single row would misalign codes and latent vectors.
        for _short_code, _h in zip(short_code, h):
            row = [_short_code] + _h.detach().cpu().numpy().tolist()
            wr.writerow(row)
        del text_feature, imgseq_feature
    f_csv.close()
    print("Finish!!!")
def get_latent(args):
    device = torch.device(args.gpu)
    print("Loading embedding model...")
    with open(
            os.path.join(CONFIG.DATASET_PATH, args.target_dataset,
                         'word_embedding.p'), "rb") as f:
        text_embedding_model = cPickle.load(f)
    with open(os.path.join(CONFIG.DATASET_PATH, args.target_dataset,
                           'word_idx.json'),
              "r",
              encoding='utf-8') as f:
        word_idx = json.load(f)
    print("Loading embedding model completed")
    print("Loading dataset...")
    full_dataset = load_fullmultimodal_data(args, CONFIG, word2idx=word_idx[1])
    print("Loading dataset completed")
    full_loader = DataLoader(full_dataset,
                             batch_size=args.batch_size,
                             shuffle=False)

    # t1 = max_sentence_len + 2 * (args.filter_shape - 1)
    t1 = CONFIG.MAX_SENTENCE_LEN
    t2 = int(math.floor(
        (t1 - args.filter_shape) / 2) + 1)  # "2" means stride size
    t3 = int(math.floor((t2 - args.filter_shape) / 2) + 1)
    args.t3 = t3

    text_embedding = nn.Embedding.from_pretrained(
        torch.FloatTensor(text_embedding_model))
    text_encoder = text_model.ConvolutionEncoder(text_embedding, t3,
                                                 args.filter_size,
                                                 args.filter_shape,
                                                 args.latent_size)
    text_decoder = text_model.DeconvolutionDecoder(text_embedding, args.tau,
                                                   t3, args.filter_size,
                                                   args.filter_shape,
                                                   args.latent_size, device)
    imgseq_encoder = imgseq_model.RNNEncoder(args.image_embedding_dim,
                                             args.num_layer,
                                             args.latent_size,
                                             bidirectional=True)
    imgseq_decoder = imgseq_model.RNNDecoder(CONFIG.MAX_SEQUENCE_LEN,
                                             args.image_embedding_dim,
                                             args.num_layer,
                                             args.latent_size,
                                             bidirectional=True)
    multimodal_encoder = multimodal_model.MultimodalEncoder(
        text_encoder, imgseq_encoder, args.latent_size)
    checkpoint = torch.load(os.path.join(CONFIG.CHECKPOINT_PATH,
                                         args.checkpoint),
                            map_location=lambda storage, loc: storage)
    multimodal_encoder.load_state_dict(checkpoint['multimodal_encoder'])
    multimodal_encoder.to(device)
    multimodal_encoder.eval()

    f_csv = open(os.path.join(CONFIG.CSV_PATH,
                              'latent_' + args.target_dataset + '.csv'),
                 'w',
                 encoding='utf-8-sig')
    wr = csv.writer(f_csv)
    for text_batch, imgseq_batch, short_code in tqdm(full_loader):
        torch.cuda.empty_cache()
        with torch.no_grad():
            text_feature = Variable(text_batch).to(device)
            imgseq_feature = Variable(imgseq_batch).to(device)
            # Forward pass inside no_grad: inference only, no autograd graph.
            h = multimodal_encoder(text_feature, imgseq_feature)

        for _short_code, _h in zip(short_code, h):
            row = [_short_code] + _h.detach().cpu().numpy().tolist()
            wr.writerow(row)
        del text_feature, imgseq_feature
    f_csv.close()
    print("Finish!!!")
def train_reconstruction(args):
    device = torch.device(args.gpu)
    print("Loading embedding model...")
    with open(
            os.path.join(CONFIG.DATASET_PATH, args.target_dataset,
                         'word_embedding.p'), "rb") as f:
        text_embedding_model = cPickle.load(f)
    with open(os.path.join(CONFIG.DATASET_PATH, args.target_dataset,
                           'word_idx.json'),
              "r",
              encoding='utf-8') as f:
        word_idx = json.load(f)
    print("Loading embedding model completed")
    print("Loading dataset...")
    train_dataset, val_dataset = load_multimodal_data(args,
                                                      CONFIG,
                                                      word2idx=word_idx[1])
    print("Loading dataset completed")
    train_loader, val_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=args.shuffle),\
             DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # t1 = max_sentence_len + 2 * (args.filter_shape - 1)
    t1 = CONFIG.MAX_SENTENCE_LEN
    t2 = int(math.floor(
        (t1 - args.filter_shape) / 2) + 1)  # "2" means stride size
    t3 = int(math.floor((t2 - args.filter_shape) / 2) + 1)
    args.t3 = t3
    text_embedding = nn.Embedding.from_pretrained(
        torch.FloatTensor(text_embedding_model))

    text_encoder = text_model.ConvolutionEncoder(text_embedding, t3,
                                                 args.filter_size,
                                                 args.filter_shape,
                                                 args.encode_latent)
    imgseq_encoder = imgseq_model.RNNEncoder(args.image_embedding_dim,
                                             args.num_layer,
                                             args.encode_latent,
                                             bidirectional=True)
    text_decoder = text_model.DeconvolutionDecoder(text_embedding, args.tau,
                                                   t3, args.filter_size,
                                                   args.filter_shape,
                                                   args.decode_latent, device)
    imgseq_decoder = imgseq_model.RNNDecoder(CONFIG.MAX_SEQUENCE_LEN,
                                             args.image_embedding_dim,
                                             args.num_layer,
                                             args.decode_latent,
                                             bidirectional=True)

    if args.pretrained:
        text_encoder_checkpoint = torch.load(
            os.path.join(CONFIG.CHECKPOINT_PATH,
                         ("text_autoencoder_" + str(args.encode_latent) +
                          "_epoch_100.pt")),
            map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(text_encoder_checkpoint['text_encoder'])
        del text_encoder_checkpoint
        text_decoder_checkpoint = torch.load(
            os.path.join(CONFIG.CHECKPOINT_PATH,
                         ("text_autoencoder_" + str(args.decode_latent) +
                          "_epoch_100.pt")),
            map_location=lambda storage, loc: storage)
        text_decoder.load_state_dict(text_decoder_checkpoint['text_decoder'])
        del text_decoder_checkpoint
        imgseq_encoder_checkpoint = torch.load(
            os.path.join(CONFIG.CHECKPOINT_PATH,
                         ("imgseq_autoencoder_" + str(args.encode_latent) +
                          "_epoch_100.pt")),
            map_location=lambda storage, loc: storage)
        imgseq_encoder.load_state_dict(
            imgseq_encoder_checkpoint['imgseq_encoder'])
        del imgseq_encoder_checkpoint
        imgseq_decoder_checkpoint = torch.load(
            os.path.join(CONFIG.CHECKPOINT_PATH,
                         ("imgseq_autoencoder_" + str(args.decode_latent) +
                          "_epoch_100.pt")),
            map_location=lambda storage, loc: storage)
        imgseq_decoder.load_state_dict(
            imgseq_decoder_checkpoint['imgseq_decoder'])
        del imgseq_decoder_checkpoint

    multimodal_encoder = multimodal_model.MultimodalEncoder(
        text_encoder, imgseq_encoder, args.latent_size, args.normalize,
        args.add_latent)
    multimodal_decoder = multimodal_model.MultimodalDecoder(
        text_decoder, imgseq_decoder, args.latent_size,
        CONFIG.MAX_SEQUENCE_LEN, args.no_decode)

    if args.resume:
        print("Restart from checkpoint")
        checkpoint = torch.load(os.path.join(CONFIG.CHECKPOINT_PATH,
                                             args.resume),
                                map_location=lambda storage, loc: storage)
        start_epoch = checkpoint['epoch']
        multimodal_encoder.load_state_dict(checkpoint['multimodal_encoder'])
        multimodal_decoder.load_state_dict(checkpoint['multimodal_decoder'])
    else:
        print("Start from initial")
        start_epoch = 0

    multimodal_autoencoder = multimodal_model.MultimodalAutoEncoder(
        multimodal_encoder, multimodal_decoder)
    text_criterion = nn.NLLLoss().to(device)
    imgseq_criterion = nn.MSELoss().to(device)
    multimodal_autoencoder.to(device)

    optimizer = AdamW(multimodal_autoencoder.parameters(),
                      lr=1.,
                      weight_decay=args.weight_decay,
                      amsgrad=True)
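    # Base lr is 1.0 so the LambdaLR multiplier from cyclical_lr becomes the
    # effective learning rate on every step.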
    step_size = args.half_cycle_interval * len(train_loader)
    clr = cyclical_lr(step_size,
                      min_lr=args.lr,
                      max_lr=args.lr * args.lr_factor)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, [clr])
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    exp_name = "Multimodal autoencoder"
    if args.normalize:
        exp_name = exp_name + " normalize"
    if args.add_latent:
        exp_name = exp_name + " add_latent"
    if args.no_decode:
        exp_name = exp_name + " no_decode"
    exp = Experiment(exp_name, capture_io=False)

    for arg, value in vars(args).items():
        exp.param(arg, value)
    try:
        multimodal_autoencoder.train()

        for epoch in range(start_epoch, args.epochs):
            print("Epoch: {}".format(epoch))
            for steps, (text_batch, imgseq_batch) in enumerate(train_loader):
                torch.cuda.empty_cache()
                text_feature = Variable(text_batch).to(device)
                imgseq_feature = Variable(imgseq_batch).to(device)
                optimizer.zero_grad()
                text_prob, imgseq_feature_hat = multimodal_autoencoder(
                    text_feature, imgseq_feature)
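                # nn.NLLLoss expects log-probabilities shaped (batch, classes, seq_len) and
                # integer targets shaped (batch, seq_len); transpose(1, 2) moves the
                # vocabulary dimension into the class position.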
                text_loss = text_criterion(text_prob.transpose(1, 2),
                                           text_feature)
                imgseq_loss = imgseq_criterion(imgseq_feature_hat,
                                               imgseq_feature)
                loss = text_loss + imgseq_loss
                del text_loss, imgseq_loss
                loss.backward()
                optimizer.step()
                scheduler.step()

                if (steps * args.batch_size) % args.log_interval == 0:
                    input_data = text_feature[0]
                    single_data = text_prob[0]
                    _, predict_index = torch.max(single_data, 1)
                    input_sentence = util.transform_idx2word(
                        input_data.detach().cpu().numpy(),
                        idx2word=word_idx[0])
                    predict_sentence = util.transform_idx2word(
                        predict_index.detach().cpu().numpy(),
                        idx2word=word_idx[0])
                    print("Epoch: {} at {} lr: {}".format(
                        epoch, str(datetime.datetime.now()),
                        str(scheduler.get_lr())))
                    print("Steps: {}".format(steps))
                    print("Loss: {}".format(loss.detach().item()))
                    print("Input Sentence:")
                    print(input_sentence)
                    print("Output Sentence:")
                    print(predict_sentence)
                    del input_data, single_data, _, predict_index
                del text_feature, text_prob, imgseq_feature, imgseq_feature_hat, loss

            exp.log("\nEpoch: {} at {} lr: {}".format(
                epoch, str(datetime.datetime.now()), str(scheduler.get_lr())))
            _avg_text_loss, _avg_imgseq_loss, _avg_loss, _rouge_1, _rouge_2 = eval_reconstruction_with_rouge(
                multimodal_autoencoder, word_idx[0], text_criterion,
                imgseq_criterion, val_loader, device)
            exp.log(
                "\nEvaluation - text_loss: {} imgseq_loss: {} loss: {}  Rouge1: {} Rouge2: {}"
                .format(_avg_text_loss, _avg_imgseq_loss, _avg_loss, _rouge_1,
                        _rouge_2))

            save_name = "multimodal_autoencoder"
            if args.normalize:
                save_name = save_name + "_normalize"
            if args.add_latent:
                save_name = save_name + "_add_latent"
            if args.no_decode:
                save_name = save_name + "_no_decode"

            util.save_models(
                {
                    'epoch': epoch + 1,
                    'multimodal_encoder': multimodal_encoder.state_dict(),
                    'multimodal_decoder': multimodal_decoder.state_dict(),
                    'avg_loss': _avg_loss,
                    'Rouge1': _rouge_1,
                    'Rouge2': _rouge_2,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()
                }, CONFIG.CHECKPOINT_PATH, save_name)

        print("Finish!!!")

    finally:
        exp.end()
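
# cyclical_lr is referenced above but not shown in this listing. Below is a minimal
# sketch of a triangular cyclical schedule compatible with the lr=1.0 + LambdaLR
# usage above; the project's actual helper may differ.
def cyclical_lr(step_size, min_lr=3e-4, max_lr=3e-3):
    # Returns a multiplier for LambdaLR; with a base lr of 1.0 it is the effective LR.
    def relative(it, step_size):
        # Position within the current triangular cycle, in [0, 1].
        cycle = math.floor(1 + it / (2 * step_size))
        x = abs(it / step_size - 2 * cycle + 1)
        return max(0.0, 1.0 - x)

    return lambda it: min_lr + (max_lr - min_lr) * relative(it, step_size)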