def main(args):
    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
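    # NOTE: `device` falls back to CPU, but much of the code below calls
    # .cuda() on tensors directly, so a CUDA GPU is effectively required.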
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # build the model (vocab sizes +1 to reserve an id for the unk token)
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)
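    # NOTE: if args.model is neither "complex" nor "rotate", `model` is never
    # bound and the code below raises a NameError.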
    if (args.do_train):
        optimizer = torch.optim.Adagrad(model.parameters(),
                                        lr=args.learning_rate,
                                        weight_decay=args.weight_decay)

        if (args.resume):
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            #Load other things too if required

        model.train()
        if "olpbench" in args.data_dir:
            train_kb = kb(os.path.join(
                args.data_dir,
                "train_data_{}.txt".format(args.train_data_type)),
                          em_map=em_map,
                          rm_map=rm_map)
            # train_kb = kb(os.path.join(args.data_dir,"train_data_thorough_r_sorted.txt"), em_map = em_map, rm_map = rm_map)
            # train_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)

        else:
            train_kb = kb(os.path.join(args.data_dir, "train.txt"),
                          em_map=em_map,
                          rm_map=rm_map)

        train_data = Dataset(train_kb.triples)
        train_sampler = RandomSampler(train_data, replacement=False)
        #train_sampler = SequentialSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # crossEntropyLoss = torch.nn.CrossEntropyLoss(reduction='mean')
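        # summed BCE over all pairwise scores; normalized to a per-score mean
        # below (loss /= target.size(0) * target.size(1))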
        BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')

        for epoch in tqdm(range(0, args.num_train_epochs), desc="epoch"):
            iteration = 0
            for train_e1_batch, train_r_batch, train_e2_batch in tqdm(
                    train_dataloader, desc="Train dataloader"):
                # skip this batch
                if (random.random() < args.skip_train_prob):
                    continue
                batch_size = len(train_e1_batch)
                train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices(
                    train_e1_batch, etoken_map, maxlen=args.max_seq_length)
                train_r_mention_tensor, train_r_lengths = convert_string_to_indices(
                    train_r_batch, rtoken_map, maxlen=args.max_seq_length)
                train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices(
                    train_e2_batch, etoken_map, maxlen=args.max_seq_length)

                train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(
                ), train_e1_lengths.cuda()
                train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(
                ), train_r_lengths.cuda()
                train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(
                ), train_e2_lengths.cuda()

                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    train_e1_mention_tensor, 0, train_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    train_r_mention_tensor, 1, train_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    train_e2_mention_tensor, 0, train_e2_lengths)
                #tail
                simi_t = model.complex_score_e1_r_with_all_ementions(
                    e1_real_lstm, e1_img_lstm, r_real_lstm, r_img_lstm,
                    e2_real_lstm, e2_img_lstm)
                #head
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    e2_real_lstm, e2_img_lstm, r_real_lstm, r_img_lstm,
                    e1_real_lstm, e1_img_lstm)
                # in-batch negatives: each (e1, r) is scored against every e2
                # in the batch, so the target is the identity matrix whose
                # diagonal marks the true pairs
                target = torch.eye(batch_size).cuda()
                # import pdb
                # pdb.set_trace()
                loss_t = BCEloss(simi_t.view(-1), target.view(-1))
                loss_h = BCEloss(simi_h.view(-1), target.view(-1))
                loss = (loss_h + loss_t) / 2
                loss /= target.size(0) * target.size(1)

                # Do the routine
                optimizer.zero_grad()
                loss.backward()
                #gradient clip?
                optimizer.step()

                if (iteration % args.print_loss_every == 0):
                    print("Current loss(avg, tail, head):", loss.item(),
                          loss_t.item(), loss_h.item())
                iteration += 1
            if (epoch % args.save_model_every == 0 and epoch != 0):
                utils.save_checkpoint(
                    {
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    args.output_dir + "/checkpoint_epoch_{}".format(epoch + 1))

    if args.do_eval:
        # eval code
        metrics = {
            'mr': 0, 'mrr': 0, 'hits1': 0, 'hits10': 0, 'hits50': 0,
            'mr_t': 0, 'mrr_t': 0, 'hits1_t': 0, 'hits10_t': 0, 'hits50_t': 0,
            'mr_h': 0, 'mrr_h': 0, 'hits1_h': 0, 'hits10_h': 0, 'hits50_h': 0,
        }

        if (args.resume and not args.do_train):
            print("Resuming from:", args.resume)
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint['state_dict'])

        model.eval()

        # get embeddings for all entity mentions
        entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
            entity_mentions,
            etoken_map,
            maxlen=args.max_seq_length,
            use_tqdm=True)
        entity_mentions_tensor = entity_mentions_tensor.cuda()
        entity_mentions_lengths = entity_mentions_lengths.cuda()

        ementions_real_lis = []
        ementions_img_lis = []
        split = 100  # can't fit all entity mentions on the GPU at once, so embed them in chunks
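        # chunk size is len(entity_mentions_tensor) // split; this assumes at
        # least `split` mentions, otherwise the range step below would be 0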
        with torch.no_grad():
            for i in tqdm(
                    range(0, len(entity_mentions_tensor),
                          len(entity_mentions_tensor) // split)):
                data = entity_mentions_tensor[i:i +
                                              len(entity_mentions_tensor) //
                                              split, :]
                data_lengths = entity_mentions_lengths[
                    i:i + len(entity_mentions_tensor) // split]
                ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                    data, 0, data_lengths)
                # a = model.Et_im(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])
                # b = model.Et_re(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])

                # a_lstm,_ = model.lstm(a)
                # a_lstm = a_lstm[:,-1,:]

                # b_lstm,_ = model.lstm(b)
                # b_lstm = b_lstm[:,-1,:]

                ementions_real_lis.append(ementions_real_lstm.cpu())
                ementions_img_lis.append(ementions_img_lstm.cpu())
        del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
        torch.cuda.empty_cache()
        ementions_real = torch.cat(ementions_real_lis).cuda()
        ementions_img = torch.cat(ementions_img_lis).cuda()
        ########################################################################
        if "olpbench" in args.data_dir:
            # test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map)
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        all_known_e2 = {}
        all_known_e1 = {}
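        # the pickle holds an (all_known_e2, all_known_e1) pair: dicts keyed
        # by (entity id, relation id), mapping to lists of entity-mention
        # indices that are known correct answers (used for filtered ranking)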
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))

        test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
        test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
        test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
            test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

        # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
        indices = torch.Tensor(
            range(len(test_kb.triples))
        )  # indices are used to fetch alternative answers during evaluation
        test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                  test_r_tokens_tensor, test_e2_tokens_tensor,
                                  test_e1_tokens_lengths,
                                  test_r_tokens_lengths,
                                  test_e2_tokens_lengths)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)
        split_dim_for_eval = 1
        if (args.embedding_dim >= 256 and "olpbench" in args.data_dir
                and "rotat" in args.model):
            split_dim_for_eval = 4
        if (args.embedding_dim >= 512 and "olpbench" in args.data_dir):
            split_dim_for_eval = 4
        if (args.embedding_dim >= 512 and "olpbench" in args.data_dir
                and "rotat" in args.model):
            split_dim_for_eval = 6
        split_dim_for_eval = 1  # NOTE: unconditional override; the heuristics above are dead code
        for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
                test_dataloader, desc="Test dataloader"):
            print(metrics)
            test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
                device), test_e1_lengths.to(device)
            test_r_tokens, test_r_lengths = test_r_tokens.to(
                device), test_r_lengths.to(device)
            test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
                device), test_e2_lengths.to(device)
            with torch.no_grad():
                e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                    test_e1_tokens, 0, test_e1_lengths)
                r_real_lstm, r_img_lstm = model.get_mention_embedding(
                    test_r_tokens, 1, test_r_lengths)
                e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                    test_e2_tokens, 0, test_e2_lengths)

            for count in tqdm(range(index.shape[0]), desc="Evaluating"):
                # breakpoint()
                this_e1_real = e1_real_lstm[count].unsqueeze(0)
                this_e1_img = e1_img_lstm[count].unsqueeze(0)
                this_r_real = r_real_lstm[count].unsqueeze(0)
                this_r_img = r_img_lstm[count].unsqueeze(0)
                this_e2_real = e2_real_lstm[count].unsqueeze(0)
                this_e2_img = e2_img_lstm[count].unsqueeze(0)
                # import pdb
                # pdb.set_trace()
                simi_t = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real,
                    this_e1_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                simi_h = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real,
                    this_e2_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                # get known answers for filtered ranking
                ind = index[count]
                this_correct_mentions_e2 = test_kb.e2_all_answers[int(
                    ind.item())]
                this_correct_mentions_e1 = test_kb.e1_all_answers[int(
                    ind.item())]

                all_correct_mentions_e2 = all_known_e2.get(
                    (em_map[test_kb.triples[int(ind.item())][0]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])
                all_correct_mentions_e1 = all_known_e1.get(
                    (em_map[test_kb.triples[int(ind.item())][2]],
                     rm_map[test_kb.triples[int(ind.item())][1]]), [])

                # compute metrics with filtered ranking: mask every known
                # correct answer, then rank the best correct score against
                # the remaining candidates; tied scores appear in both
                # `greatereq` and `equal`, so each tie adds 1.5 to the rank
                # rather than the conventional mid-rank 0.5
                best_score = simi_t[this_correct_mentions_e2].max()
                simi_t[
                    all_correct_mentions_e2] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi_t.ge(best_score).float()
                equal = simi_t.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

                metrics['mr_t'] += rank
                metrics['mrr_t'] += 1.0 / rank
                metrics['hits1_t'] += rank.le(1).float()
                metrics['hits10_t'] += rank.le(10).float()
                metrics['hits50_t'] += rank.le(50).float()

                best_score = simi_h[this_correct_mentions_e1].max()
                simi_h[
                    all_correct_mentions_e1] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi_h.ge(best_score).float()
                equal = simi_h.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0
                metrics['mr_h'] += rank
                metrics['mrr_h'] += 1.0 / rank
                metrics['hits1_h'] += rank.le(1).float()
                metrics['hits10_h'] += rank.le(10).float()
                metrics['hits50_h'] += rank.le(50).float()

                metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                metrics['mrr'] = (metrics['mrr_h'] + metrics['mrr_t']) / 2
                metrics['hits1'] = (metrics['hits1_h'] +
                                    metrics['hits1_t']) / 2
                metrics['hits10'] = (metrics['hits10_h'] +
                                     metrics['hits10_t']) / 2
                metrics['hits50'] = (metrics['hits50_h'] +
                                     metrics['hits50_t']) / 2

        for key in metrics:
            metrics[key] = metrics[key] / len(test_kb.triples)
        print(metrics)
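
# A hypothetical invocation of the entry point above, assuming an argparse
# wrapper exposes the attributes read via `args.*` (flag names are inferred
# from that usage, not confirmed by the source):
#
#   python train.py \
#       --data_dir data/olpbench --model complex --embedding_dim 256 \
#       --do_train --do_eval --train_data_type thorough \
#       --train_batch_size 512 --eval_batch_size 64 \
#       --num_train_epochs 50 --learning_rate 0.1 --seed 42 \
#       --output_dir models/complex_lstm
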
def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # for batch in train_loader:
    # 	inputs, \
    #        normalizer_loss, \
    #        normalizer_metric, \
    #        labels, \
    #        label_ids, \
    #        filter_mask, \
    #        batch_shared_entities = train_data.input_and_labels_to_device(
    #            batch,
    #            training=True,
    #            device=train_data.device
    #        )
    # 	import pdb
    # 	pdb.set_trace()

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    # create entity_token_indices and entity_lengths
    # [[max length indices for entity 0 ], [max length indices for entity 1], [max length indices for entity 2], ...]
    # [length of entity 0, length of entity 1, length of entity 2, ...]
    # entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    # relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(relation_mentions, rtoken_map, maxlen=args.max_seq_length, use_tqdm=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # build the model (vocab sizes +1 to reserve an id for the unk token)
    if args.model == "complex":
        if args.separate_lstms:
            model = complexLSTM_2(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
        else:
            model = complexLSTM(
                len(etoken_map) + 1,
                len(rtoken_map) + 1,
                args.embedding_dim,
                initial_token_embedding=args.initial_token_embedding,
                entity_tokens=etokens,
                relation_tokens=rtokens,
                lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(
            len(etoken_map) + 1,
            len(rtoken_map) + 1,
            args.embedding_dim,
            initial_token_embedding=args.initial_token_embedding,
            entity_tokens=etokens,
            relation_tokens=rtokens,
            gamma=args.gamma_rotate,
            lstm_dropout=args.lstm_dropout)

    if args.do_eval:
        best_model = -1
        best_metrics = None
        if "olpbench" in args.data_dir:
            # test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map)
            test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                         em_map=em_map,
                         rm_map=rm_map)
        else:
            test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                         em_map=em_map,
                         rm_map=rm_map)

        print("Loading all_known pickled data...(takes times since large)")
        all_known_e2 = {}
        all_known_e1 = {}
        all_known_e2, all_known_e1 = pickle.load(
            open(
                os.path.join(
                    args.data_dir,
                    "all_knowns_{}_linked.pkl".format(args.train_data_type)),
                "rb"))
        models = os.listdir("models/author_data_2lstm_thorough")  # NOTE: hardcoded checkpoint directory
        for model_path in tqdm(models):
            try:
                model_path = os.path.join("models/author_data_2lstm_thorough",
                                          model_path)
                # eval code
                metrics = {
                    'mr': 0, 'mrr': 0, 'hits1': 0, 'hits10': 0, 'hits50': 0,
                    'mr_t': 0, 'mrr_t': 0, 'hits1_t': 0, 'hits10_t': 0,
                    'hits50_t': 0,
                    'mr_h': 0, 'mrr_h': 0, 'hits1_h': 0, 'hits10_h': 0,
                    'hits50_h': 0,
                }

                checkpoint = torch.load(
                    model_path, map_location=lambda storage, loc: storage)
                model.load_state_dict(checkpoint['state_dict'])

                model.eval()

                # get embeddings for all entity mentions
                entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
                    entity_mentions,
                    etoken_map,
                    maxlen=args.max_seq_length,
                    use_tqdm=False)
                entity_mentions_tensor = entity_mentions_tensor.cuda()
                entity_mentions_lengths = entity_mentions_lengths.cuda()

                ementions_real_lis = []
                ementions_img_lis = []
                split = 100  # can't fit all entity mentions on the GPU at once, so embed them in chunks
                with torch.no_grad():
                    for i in range(0, len(entity_mentions_tensor),
                                   len(entity_mentions_tensor) // split):
                        data = entity_mentions_tensor[
                            i:i + len(entity_mentions_tensor) // split, :]
                        data_lengths = entity_mentions_lengths[
                            i:i + len(entity_mentions_tensor) // split]
                        ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                            data, 0, data_lengths)
                        ementions_real_lis.append(ementions_real_lstm.cpu())
                        ementions_img_lis.append(ementions_img_lstm.cpu())
                del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
                torch.cuda.empty_cache()
                ementions_real = torch.cat(ementions_real_lis).cuda()
                ementions_img = torch.cat(ementions_img_lis).cuda()
                ########################################################################

                test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 0],
                    etoken_map,
                    maxlen=args.max_seq_length)
                test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 1],
                    rtoken_map,
                    maxlen=args.max_seq_length)
                test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
                    test_kb.triples[:, 2],
                    etoken_map,
                    maxlen=args.max_seq_length)

                # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
                indices = torch.Tensor(
                    range(len(test_kb.triples))
                )  # indices are used to fetch alternative answers during evaluation
                test_data = TensorDataset(indices, test_e1_tokens_tensor,
                                          test_r_tokens_tensor,
                                          test_e2_tokens_tensor,
                                          test_e1_tokens_lengths,
                                          test_r_tokens_lengths,
                                          test_e2_tokens_lengths)
                test_sampler = SequentialSampler(test_data)
                test_dataloader = DataLoader(test_data,
                                             sampler=test_sampler,
                                             batch_size=args.eval_batch_size)
                split_dim_for_eval = 1
                if (args.embedding_dim >= 256 and "olpbench" in args.data_dir
                        and "rotat" in args.model):
                    split_dim_for_eval = 4
                if (args.embedding_dim >= 512 and "olpbench" in args.data_dir):
                    split_dim_for_eval = 4
                if (args.embedding_dim >= 512 and "olpbench" in args.data_dir
                        and "rotat" in args.model):
                    split_dim_for_eval = 6
                split_dim_for_eval = 1  # NOTE: unconditional override; the heuristics above are dead code
                for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in test_dataloader:
                    test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
                        device), test_e1_lengths.to(device)
                    test_r_tokens, test_r_lengths = test_r_tokens.to(
                        device), test_r_lengths.to(device)
                    test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
                        device), test_e2_lengths.to(device)
                    with torch.no_grad():
                        e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                            test_e1_tokens, 0, test_e1_lengths)
                        r_real_lstm, r_img_lstm = model.get_mention_embedding(
                            test_r_tokens, 1, test_r_lengths)
                        e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                            test_e2_tokens, 0, test_e2_lengths)

                    for count in range(index.shape[0]):
                        # breakpoint()
                        this_e1_real = e1_real_lstm[count].unsqueeze(0)
                        this_e1_img = e1_img_lstm[count].unsqueeze(0)
                        this_r_real = r_real_lstm[count].unsqueeze(0)
                        this_r_img = r_img_lstm[count].unsqueeze(0)
                        this_e2_real = e2_real_lstm[count].unsqueeze(0)
                        this_e2_img = e2_img_lstm[count].unsqueeze(0)
                        simi_t = model.complex_score_e1_r_with_all_ementions(
                            this_e1_real,
                            this_e1_img,
                            this_r_real,
                            this_r_img,
                            ementions_real,
                            ementions_img,
                            split=split_dim_for_eval).squeeze(0)
                        simi_h = model.complex_score_e2_r_with_all_ementions(
                            this_e2_real,
                            this_e2_img,
                            this_r_real,
                            this_r_img,
                            ementions_real,
                            ementions_img,
                            split=split_dim_for_eval).squeeze(0)
                        # get known answers for filtered ranking
                        ind = index[count]
                        this_correct_mentions_e2 = test_kb.e2_all_answers[int(
                            ind.item())]
                        this_correct_mentions_e1 = test_kb.e1_all_answers[int(
                            ind.item())]

                        all_correct_mentions_e2 = all_known_e2.get(
                            (em_map[test_kb.triples[int(ind.item())][0]],
                             rm_map[test_kb.triples[int(ind.item())][1]]), [])
                        all_correct_mentions_e1 = all_known_e1.get(
                            (em_map[test_kb.triples[int(ind.item())][2]],
                             rm_map[test_kb.triples[int(ind.item())][1]]), [])

                        # compute metrics
                        best_score = simi_t[this_correct_mentions_e2].max()
                        simi_t[
                            all_correct_mentions_e2] = -20000000  # MOST NEGATIVE VALUE
                        greatereq = simi_t.ge(best_score).float()
                        equal = simi_t.eq(best_score).float()
                        rank = greatereq.sum() + 1 + equal.sum() / 2.0

                        metrics['mr_t'] += rank
                        metrics['mrr_t'] += 1.0 / rank
                        metrics['hits1_t'] += rank.le(1).float()
                        metrics['hits10_t'] += rank.le(10).float()
                        metrics['hits50_t'] += rank.le(50).float()

                        best_score = simi_h[this_correct_mentions_e1].max()
                        simi_h[
                            all_correct_mentions_e1] = -20000000  # MOST NEGATIVE VALUE
                        greatereq = simi_h.ge(best_score).float()
                        equal = simi_h.eq(best_score).float()
                        rank = greatereq.sum() + 1 + equal.sum() / 2.0
                        metrics['mr_h'] += rank
                        metrics['mrr_h'] += 1.0 / rank
                        metrics['hits1_h'] += rank.le(1).float()
                        metrics['hits10_h'] += rank.le(10).float()
                        metrics['hits50_h'] += rank.le(50).float()

                        metrics['mr'] = (metrics['mr_h'] + metrics['mr_t']) / 2
                        metrics['mrr'] = (metrics['mrr_h'] +
                                          metrics['mrr_t']) / 2
                        metrics['hits1'] = (metrics['hits1_h'] +
                                            metrics['hits1_t']) / 2
                        metrics['hits10'] = (metrics['hits10_h'] +
                                             metrics['hits10_t']) / 2
                        metrics['hits50'] = (metrics['hits50_h'] +
                                             metrics['hits50_t']) / 2

                for key in metrics:
                    metrics[key] = metrics[key] / len(test_kb.triples)
                if (best_metrics is None
                        or best_metrics['hits1'] < metrics['hits1']):
                    best_model = model_path
                    best_metrics = metrics
                print("best_hits1:", best_metrics['hits1'])
            except Exception:
                # skip checkpoints that fail to load or evaluate
                continue
        print(best_metrics)
        print(best_model)
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    relation_mentions, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    if args.model == "complex":
        model = complexLSTM(len(etoken_map) + 1,
                            len(rtoken_map) + 1,
                            args.embedding_dim,
                            initial_token_embedding=None,
                            entity_tokens=etokens,
                            relation_tokens=rtokens,
                            lstm_dropout=args.lstm_dropout)
    elif args.model == "rotate":
        model = rotatELSTM(len(etoken_map) + 1,
                           len(rtoken_map) + 1,
                           args.embedding_dim,
                           initial_token_embedding=None,
                           entity_tokens=etokens,
                           relation_tokens=rtokens,
                           gamma=args.gamma_rotate,
                           lstm_dropout=args.lstm_dropout)
    if args.resume:
        print("Resuming from:", args.resume)
        checkpoint = torch.load(args.resume)
        # apply the checkpoint weights (mirrors the resume handling above)
        model.load_state_dict(checkpoint['state_dict'])
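
# A minimal, self-contained sketch (not part of the pipeline above) of the
# tie-aware rank formula used in the eval loops. Because `ge` already counts
# tied scores, each tie adds 1.5 to the rank rather than the usual mid-rank
# 0.5.
def _fractional_rank_demo():
    import torch
    scores = torch.tensor([5.0, 3.0, 3.0, 1.0])
    best_score = torch.tensor(3.0)
    greatereq = scores.ge(best_score).float()  # 3 entries: 5, 3, 3
    equal = scores.eq(best_score).float()      # 2 entries: 3, 3
    rank = greatereq.sum() + 1 + equal.sum() / 2.0
    assert rank.item() == 5.0  # 3 + 1 + 2/2
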
def main(args):
    hits_1_triple = []
    hits_1_correct_answers = []
    hits_1_model_top10 = []
    hits_1_evidence = []
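    # test-triple indices on which a baseline already achieves tail hits@1;
    # these are tallied into baseline_correct and excluded from the sampled
    # qualitative output below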
    baseline_tail_hits1_indices = set([
        36, 91, 95, 101, 119, 158, 282, 397, 638, 728, 740, 763, 914, 959, 972,
        992, 1184, 1478, 1669, 1686, 1732, 1795, 1796, 1822, 1826, 1845, 1924,
        1939, 1943, 2055, 2178, 2317, 2319, 2325, 2482, 2513, 2589, 2627, 2674,
        2736, 2862, 2985, 3049, 3311, 3327, 3491, 3660, 3728, 3817, 3818, 4111,
        4263, 4387, 4437, 4438, 4452, 4525, 4591, 4670, 4856, 5114, 5159, 5318,
        5587, 5851, 5857, 5893, 5925, 5942, 5990, 6056, 6079, 6119, 6172, 6195,
        6211, 6228, 6262, 6267, 6460, 6491, 6509, 6584, 6676, 6699, 6862, 6982,
        7057, 7078, 7084, 7221, 7597, 7733, 7837, 8045, 8278, 8326, 8380, 8433,
        8453, 8479, 8534, 8540, 8742, 8813, 8860, 8906, 8930, 9234, 9333, 9500,
        9535, 9589, 9663, 9803, 9809, 9866, 9999
    ])
    baseline_correct = 0
    # nothits_50_triple = []
    # nothits_50_correct_answers = []
    # nothits_50_model_top10 = []

    injected_rels = kb(args.evidence_file, em_map=None,
                       rm_map=None).triples[:, 1].reshape(-1, args.n_times)
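    # (one row of args.n_times injected relation mentions per test triple,
    #  aligned with the test set by row index)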

    # read token maps
    etokens, etoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "entity_token_id_map.txt"))
    rtokens, rtoken_map = utils.get_tokens_map(
        os.path.join(args.data_dir, "mapped_to_ids",
                     "relation_token_id_map.txt"))
    entity_mentions, em_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "entity_id_map.txt"))
    _, rm_map = utils.read_mentions(
        os.path.join(args.data_dir, "mapped_to_ids", "relation_id_map.txt"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # build the model (vocab sizes +1 to reserve an id for the unk token)
    model = complexLSTM(len(etoken_map) + 1,
                        len(rtoken_map) + 1,
                        args.embedding_dim,
                        initial_token_embedding=args.initial_token_embedding,
                        entity_tokens=etokens,
                        relation_tokens=rtokens,
                        lstm_dropout=0)

    if (args.resume):
        print("Resuming from:", args.resume)
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])

    model.eval()

    # get embeddings for all entity mentions
    entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(
        entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
    entity_mentions_tensor = entity_mentions_tensor.cuda()
    entity_mentions_lengths = entity_mentions_lengths.cuda()

    ementions_real_lis = []
    ementions_img_lis = []
    split = 100  # can't fit all entity mentions on the GPU at once, so embed them in chunks
    with torch.no_grad():
        for i in tqdm(
                range(0, len(entity_mentions_tensor),
                      len(entity_mentions_tensor) // split)):
            data = entity_mentions_tensor[i:i + len(entity_mentions_tensor) //
                                          split, :]
            data_lengths = entity_mentions_lengths[
                i:i + len(entity_mentions_tensor) // split]
            ementions_real_lstm, ementions_img_lstm = model.get_mention_embedding(
                data, 0, data_lengths)

            ementions_real_lis.append(ementions_real_lstm.cpu())
            ementions_img_lis.append(ementions_img_lstm.cpu())
    del entity_mentions_tensor, ementions_real_lstm, ementions_img_lstm
    torch.cuda.empty_cache()
    ementions_real = torch.cat(ementions_real_lis).cuda()
    ementions_img = torch.cat(ementions_img_lis).cuda()
    ########################################################################

    if "olpbench" in args.data_dir:
        test_kb = kb(os.path.join(args.data_dir, "test_data.txt"),
                     em_map=em_map,
                     rm_map=rm_map)
    else:
        test_kb = kb(os.path.join(args.data_dir, "test.txt"),
                     em_map=em_map,
                     rm_map=rm_map)
    print("Loading all_known pickled data...(takes times since large)")
    all_known_e2 = {}
    all_known_e1 = {}
    all_known_e2, all_known_e1 = pickle.load(
        open(os.path.join(args.data_dir, "all_knowns_thorough_linked.pkl"),
             "rb"))

    test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 0], etoken_map, maxlen=args.max_seq_length)
    test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 1], rtoken_map, maxlen=args.max_seq_length)
    test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(
        test_kb.triples[:, 2], etoken_map, maxlen=args.max_seq_length)

    # e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
    indices = torch.Tensor(
        range(len(test_kb.triples))
    )  # indices are used to fetch alternative answers during evaluation
    test_data = TensorDataset(indices, test_e1_tokens_tensor,
                              test_r_tokens_tensor, test_e2_tokens_tensor,
                              test_e1_tokens_lengths, test_r_tokens_lengths,
                              test_e2_tokens_lengths)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data,
                                 sampler=test_sampler,
                                 batch_size=args.eval_batch_size)
    split_dim_for_eval = 1
    if (args.embedding_dim >= 512 and "olpbench" in args.data_dir):
        split_dim_for_eval = 4
    for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(
            test_dataloader, desc="Test dataloader"):
        test_e1_tokens, test_e1_lengths = test_e1_tokens.to(
            device), test_e1_lengths.to(device)
        test_r_tokens, test_r_lengths = test_r_tokens.to(
            device), test_r_lengths.to(device)
        test_e2_tokens, test_e2_lengths = test_e2_tokens.to(
            device), test_e2_lengths.to(device)
        with torch.no_grad():
            e1_real_lstm, e1_img_lstm = model.get_mention_embedding(
                test_e1_tokens, 0, test_e1_lengths)
            r_real_lstm, r_img_lstm = model.get_mention_embedding(
                test_r_tokens, 1, test_r_lengths)
            e2_real_lstm, e2_img_lstm = model.get_mention_embedding(
                test_e2_tokens, 0, test_e2_lengths)

        for count in tqdm(range(index.shape[0]), desc="Evaluating"):
            this_e1_real = e1_real_lstm[count].unsqueeze(0)
            this_e1_img = e1_img_lstm[count].unsqueeze(0)
            this_r_real = r_real_lstm[count].unsqueeze(0)
            this_r_img = r_img_lstm[count].unsqueeze(0)
            this_e2_real = e2_real_lstm[count].unsqueeze(0)
            this_e2_img = e2_img_lstm[count].unsqueeze(0)

            # get known answers for filtered ranking
            ind = index[count]
            this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
            this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())]

            all_correct_mentions_e2 = all_known_e2.get(
                (em_map[test_kb.triples[int(ind.item())][0]],
                 rm_map[test_kb.triples[int(ind.item())][1]]), [])
            all_correct_mentions_e1 = all_known_e1.get(
                (em_map[test_kb.triples[int(ind.item())][2]],
                 rm_map[test_kb.triples[int(ind.item())][1]]), [])
            if (args.head_or_tail == "tail"):
                simi = model.complex_score_e1_r_with_all_ementions(
                    this_e1_real,
                    this_e1_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                best_score = simi[this_correct_mentions_e2].max()
                simi[
                    all_correct_mentions_e2] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi.ge(best_score).float()
                equal = simi.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

            else:
                simi = model.complex_score_e2_r_with_all_ementions(
                    this_e2_real,
                    this_e2_img,
                    this_r_real,
                    this_r_img,
                    ementions_real,
                    ementions_img,
                    split=split_dim_for_eval).squeeze(0)
                best_score = simi[this_correct_mentions_e1].max()
                simi[
                    all_correct_mentions_e1] = -20000000  # MOST NEGATIVE VALUE
                greatereq = simi.ge(best_score).float()
                equal = simi.eq(best_score).float()
                rank = greatereq.sum() + 1 + equal.sum() / 2.0

            if int(ind.item()) in baseline_tail_hits1_indices:
                if rank <= 1:
                    baseline_correct += 1
                continue
            if (rank <= 1):
                #hits1
                hits_1_triple.append([
                    test_kb.triples[int(ind.item())][0],
                    test_kb.triples[int(ind.item())][1],
                    test_kb.triples[int(ind.item())][2]
                ])
                hits_1_evidence.append(injected_rels[int(ind.item())].tolist())
                if (args.head_or_tail == "tail"):
                    # hits_1_correct_answers.append(this_correct_mentions_e2)
                    hits_1_correct_answers.append(
                        [entity_mentions[x] for x in this_correct_mentions_e2])
                else:
                    hits_1_correct_answers.append(
                        [entity_mentions[x] for x in this_correct_mentions_e1])
                hits_1_model_top10.append([])  # top-10 extraction is disabled; cf. the commented nothits_50 block below
            # elif(rank>50):
            # 	#nothits50
            # 	nothits_50_triple.append([test_kb.triples[int(ind.item())][0],test_kb.triples[int(ind.item())][1],test_kb.triples[int(ind.item())][2]])
            # 	if(args.head_or_tail=="tail"):
            # 		nothits_50_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e2])
            # 	else:
            # 		nothits_50_correct_answers.append([entity_mentions[x] for x in this_correct_mentions_e1])
            # 	tmp = simi.sort()[1].tolist()[::-1][:10]
            # 	nothits_50_model_top10.append([entity_mentions[x] for x in tmp])

    indices = list(range(len(hits_1_triple)))
    random.shuffle(indices)
    indices = indices[:args.sample]
    print(baseline_correct)
    for ind in indices:
        print(ind, "|", hits_1_triple[ind], "|", hits_1_correct_answers[ind],
              "|", hits_1_model_top10[ind], "|", hits_1_evidence[ind])
def main(args):
	random.seed(args.seed)
	np.random.seed(args.seed)
	torch.manual_seed(args.seed)
	
	# for batch in train_loader:
	# 	inputs, \
	#        normalizer_loss, \
	#        normalizer_metric, \
	#        labels, \
	#        label_ids, \
	#        filter_mask, \
	#        batch_shared_entities = train_data.input_and_labels_to_device(
	#            batch,
	#            training=True,
	#            device=train_data.device
	#        )
	# 	import pdb
	# 	pdb.set_trace()


	# read token maps
	etokens, etoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","entity_token_id_map.txt"))
	rtokens, rtoken_map = utils.get_tokens_map(os.path.join(args.data_dir,"mapped_to_ids","relation_token_id_map.txt"))
	entity_mentions,em_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","entity_id_map.txt"))
	relation_mentions,rm_map = utils.read_mentions(os.path.join(args.data_dir,"mapped_to_ids","relation_id_map.txt"))

	# create entity_token_indices and entity_lengths
	# [[max length indices for entity 0 ], [max length indices for entity 1], [max length indices for entity 2], ...]
	# [length of entity 0, length of entity 1, length of entity 2, ...]
	entity_token_indices, entity_lengths = utils.get_token_indices_from_mention_indices(entity_mentions, etoken_map, maxlen=args.max_seq_length, use_tqdm=True)
	relation_token_indices, relation_lengths = utils.get_token_indices_from_mention_indices(relation_mentions, rtoken_map, maxlen=args.max_seq_length, use_tqdm=True)

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	if not args.do_train and not args.do_eval:
		raise ValueError("At least one of `do_train` or `do_eval` must be True.")

	# build the model (vocab sizes +1 to reserve an id for the unk token)
	if args.model=="complex":
		if args.separate_lstms:
			model = complexLSTM_2(len(etoken_map)+1,len(rtoken_map)+1,args.embedding_dim, initial_token_embedding =args.initial_token_embedding, entity_tokens = etokens, relation_tokens = rtokens, lstm_dropout=args.lstm_dropout)
		else:
			model = complexLSTM(len(etoken_map)+1,len(rtoken_map)+1,args.embedding_dim, initial_token_embedding =args.initial_token_embedding, entity_tokens = etokens, relation_tokens = rtokens, lstm_dropout=args.lstm_dropout)
	elif args.model == "rotate":
		model = rotatELSTM(len(etoken_map)+1,len(rtoken_map)+1,args.embedding_dim, initial_token_embedding =args.initial_token_embedding, entity_tokens = etokens, relation_tokens = rtokens, gamma = args.gamma_rotate, lstm_dropout=args.lstm_dropout)
	if(args.do_train):
		data_config = {'input_file': 'train_data_thorough.txt', 'batch_size': args.train_batch_size, 'use_batch_shared_entities': True, 'min_size_batch_labels': args.train_batch_size, 'max_size_prefix_label': 64, 'device': 0}
		expt_settings = {'loss': 'bce', 'replace_entities_by_tokens': True, 'replace_relations_by_tokens': True, 'max_lengths_tuple': [10, 10]}
		train_data = OneToNMentionRelationDataset(dataset_dir=os.path.join(args.data_dir,"mapped_to_ids"), is_training_data=True, **data_config, **expt_settings)
		train_data.create_data_tensors(
			dataset_dir=os.path.join(args.data_dir,"mapped_to_ids"),
			train_input_file='train_data_thorough.txt',
			valid_input_file='validation_data_linked.txt',
			test_input_file='test_data.txt',
		)
		train_loader = train_data.get_loader(
			shuffle=True,
			num_workers=8,
			drop_last=True,
		)
		optimizer = torch.optim.Adagrad(model.parameters(),lr=args.learning_rate,weight_decay=args.weight_decay)

		if(args.resume):
			print("Resuming from:",args.resume)
			checkpoint = torch.load(args.resume)
			model.load_state_dict(checkpoint['state_dict'])
			optimizer.load_state_dict(checkpoint['optimizer'])
			#Load other things too if required

		model.train()
		# if "olpbench" in args.data_dir:
		# 	train_kb = kb(os.path.join(args.data_dir,"train_data_{}.txt".format(args.train_data_type)), em_map = em_map, rm_map = rm_map)
		# 	# train_kb = kb(os.path.join(args.data_dir,"train_data_thorough_r_sorted.txt"), em_map = em_map, rm_map = rm_map)
		# 	# train_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)

		# else:
		# 	train_kb = kb(os.path.join(args.data_dir,"train.txt"), em_map = em_map, rm_map = rm_map)
		
		# train_data = Dataset(train_kb.triples)
		# train_sampler = RandomSampler(train_data,replacement=False)
		# #train_sampler = SequentialSampler(train_data)
		# train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

		# crossEntropyLoss = torch.nn.CrossEntropyLoss(reduction='mean')
		BCEloss = torch.nn.BCEWithLogitsLoss(reduction='sum')
		for epoch in tqdm(range(0,args.num_train_epochs), desc="epoch"):
			iteration = 0
			for batch in tqdm(train_loader, desc="Train dataloader"):
				inputs, \
				normalizer_loss, \
				normalizer_metric, \
				labels, \
				label_ids, \
				filter_mask, \
				batch_shared_entities = train_data.input_and_labels_to_device(
					batch,
					training=True,
					device="cpu"
				)
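				# the loader returns CPU tensors here; labels (and the token
				# tensors built below) are moved to the GPU explicitly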
				labels = labels.cuda()
				all_outputs = []
				for mode,model_inputs in zip(["head","tail"],inputs):
					if model_inputs is None:
						continue
					# subtract two from the author's indices because our id maps start two lower
					if mode=="head":
						batch_e2_indices = model_inputs[1] - 2
						batch_r_indices  = model_inputs[0] - 2
						batch_e1_indices = batch_shared_entities - 2
					else:
						batch_e1_indices = model_inputs[0] - 2
						batch_r_indices  = model_inputs[1] - 2
						batch_e2_indices = batch_shared_entities - 2
					# import pdb 
					# pdb.set_trace()

					# convert these indices back into string (compatibility with my code)
					# tik = time.time()
					# batch_e1_strings = convert_mention_index_to_string(batch_e1_indices.squeeze(1), entity_mentions)
					# batch_r_strings  = convert_mention_index_to_string(batch_r_indices.squeeze(1), relation_mentions)
					# batch_e2_strings = convert_mention_index_to_string(batch_e2_indices.squeeze(1), entity_mentions)
					# # print("convert_mention_index_to_string:",time.time() - tik)
					# # do what you used to do now
					# # tik = time.time()
					# train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices(batch_e1_strings,etoken_map,maxlen=args.max_seq_length)
					# train_r_mention_tensor, train_r_lengths   = convert_string_to_indices(batch_r_strings,rtoken_map,maxlen=args.max_seq_length)
					# train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices(batch_e2_strings,etoken_map,maxlen=args.max_seq_length)
					# print("convert_string_to_indices:",time.time() - tik)
					
					train_e1_mention_tensor, train_e1_lengths = convert_mention_to_token_indices(batch_e1_indices.squeeze(1), entity_token_indices, entity_lengths)
					train_r_mention_tensor, train_r_lengths   = convert_mention_to_token_indices(batch_r_indices.squeeze(1), relation_token_indices, relation_lengths)
					train_e2_mention_tensor, train_e2_lengths = convert_mention_to_token_indices(batch_e2_indices.squeeze(1), entity_token_indices, entity_lengths)


					train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(), train_e1_lengths.cuda()
					train_r_mention_tensor, train_r_lengths   = train_r_mention_tensor.cuda(), train_r_lengths.cuda()
					train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(), train_e2_lengths.cuda()

					# tik = time.time()
					e1_real_lstm, e1_img_lstm = model.get_mention_embedding(train_e1_mention_tensor,0,train_e1_lengths)
					r_real_lstm, r_img_lstm   = model.get_mention_embedding(train_r_mention_tensor,1,train_r_lengths)
					e2_real_lstm, e2_img_lstm = model.get_mention_embedding(train_e2_mention_tensor,0,train_e2_lengths)
					# print("get_mention_embedding:",time.time() - tik)

					# tik = time.time()
					if mode=="head":
						output = model.complex_score_e2_r_with_all_ementions(e2_real_lstm,e2_img_lstm,r_real_lstm,r_img_lstm,e1_real_lstm,e1_img_lstm)
					else:
						output = model.complex_score_e1_r_with_all_ementions(e1_real_lstm,e1_img_lstm,r_real_lstm,r_img_lstm,e2_real_lstm,e2_img_lstm)
					# print("model_scoring:",time.time() - tik)

					all_outputs.append(output)
				
				all_outputs = torch.cat(all_outputs)
				loss = BCEloss(all_outputs.view(-1),labels.view(-1))
				# loss = loss.sum()
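				# scale the summed BCE by the loader-provided normalizer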
				loss /= normalizer_loss
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()
				if(iteration%args.print_loss_every==0):
					print("Current loss:",loss.item())
				iteration+=1
			if(epoch%args.save_model_every==0):
				utils.save_checkpoint({
						'state_dict':model.state_dict(),
						'optimizer':optimizer.state_dict()
						},args.output_dir+"/checkpoint_epoch_{}".format(epoch+1))

		# for epoch in tqdm(range(0,args.num_train_epochs),desc="epoch"):
		# 	iteration = 0
		# 	for train_e1_batch, train_r_batch, train_e2_batch in tqdm(train_dataloader,desc="Train dataloader"):
		# 		# skip this batch
		# 		if(random.random()<args.skip_train_prob):
		# 			continue	
		# 		batch_size = len(train_e1_batch)
		# 		train_e1_mention_tensor, train_e1_lengths = convert_string_to_indices(train_e1_batch,etoken_map,maxlen=args.max_seq_length)
		# 		train_r_mention_tensor, train_r_lengths = convert_string_to_indices(train_r_batch,rtoken_map,maxlen=args.max_seq_length)
		# 		train_e2_mention_tensor, train_e2_lengths = convert_string_to_indices(train_e2_batch,etoken_map,maxlen=args.max_seq_length)

		# 		train_e1_mention_tensor, train_e1_lengths = train_e1_mention_tensor.cuda(), train_e1_lengths.cuda()
		# 		train_r_mention_tensor, train_r_lengths = train_r_mention_tensor.cuda(), train_r_lengths.cuda()
		# 		train_e2_mention_tensor, train_e2_lengths = train_e2_mention_tensor.cuda(), train_e2_lengths.cuda()


		# 		e1_real_lstm, e1_img_lstm = model.get_mention_embedding(train_e1_mention_tensor,0,train_e1_lengths)
		# 		r_real_lstm, r_img_lstm = model.get_mention_embedding(train_r_mention_tensor,1,train_r_lengths)
		# 		e2_real_lstm, e2_img_lstm = model.get_mention_embedding(train_e2_mention_tensor,0,train_e2_lengths)
		# 		#tail
		# 		simi_t = model.complex_score_e1_r_with_all_ementions(e1_real_lstm,e1_img_lstm,r_real_lstm,r_img_lstm,e2_real_lstm,e2_img_lstm)
		# 		#head
		# 		simi_h = model.complex_score_e2_r_with_all_ementions(e2_real_lstm,e2_img_lstm,r_real_lstm,r_img_lstm,e1_real_lstm,e1_img_lstm)
		# 		# change the loss suitably
		# 		target = torch.eye(batch_size).cuda()
		# 		# import pdb
		# 		# pdb.set_trace()
		# 		loss_t = BCEloss(simi_t.view(-1),target.view(-1))
		# 		loss_h = BCEloss(simi_h.view(-1),target.view(-1))
		# 		loss = (loss_h+loss_t)/2
		# 		loss /= target.size(0) * target.size(1)

		# 		# Do the routine
		# 		optimizer.zero_grad()
		# 		loss.backward()
		# 		#gradient clip?
		# 		optimizer.step()

		# 		if(iteration%args.print_loss_every==0):
		# 			print("Current loss(avg, tail, head):",loss.item(), loss_t.item(), loss_h.item())
		# 		iteration+=1
		# 	if(epoch%args.save_model_every==0 and epoch!=0):
		# 		utils.save_checkpoint({
		# 				'state_dict':model.state_dict(),
		# 				'optimizer':optimizer.state_dict()
		# 				},args.output_dir+"/checkpoint_epoch_{}".format(epoch+1))


	if args.do_eval:
		# eval code
		metrics = {
			'mr': 0, 'mrr': 0, 'hits1': 0, 'hits10': 0, 'hits50': 0,
			'mr_t': 0, 'mrr_t': 0, 'hits1_t': 0, 'hits10_t': 0, 'hits50_t': 0,
			'mr_h': 0, 'mrr_h': 0, 'hits1_h': 0, 'hits10_h': 0, 'hits50_h': 0,
		}

		if(args.resume and not args.do_train):
			print("Resuming from:",args.resume)
			checkpoint = torch.load(args.resume,map_location=lambda storage, loc: storage)
			model.load_state_dict(checkpoint['state_dict'])

		model.eval()

		# get embeddings for all entity mentions
		entity_mentions_tensor, entity_mentions_lengths = convert_string_to_indices(entity_mentions,etoken_map,maxlen=args.max_seq_length,use_tqdm=True)
		entity_mentions_tensor = entity_mentions_tensor.cuda()
		entity_mentions_lengths = entity_mentions_lengths.cuda()

		ementions_real_lis = []
		ementions_img_lis = []
		split = 100 # can't fit all entity mentions on the GPU at once, so embed them in chunks
		with torch.no_grad():
			for i in tqdm(range(0,len(entity_mentions_tensor),len(entity_mentions_tensor)//split)):
				data = entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:]
				data_lengths = entity_mentions_lengths[i:i+len(entity_mentions_tensor)//split]
				ementions_real_lstm,ementions_img_lstm = model.get_mention_embedding(data,0,data_lengths)			
				# a = model.Et_im(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])
				# b = model.Et_re(entity_mentions_tensor[i:i+len(entity_mentions_tensor)//split,:])
				
				# a_lstm,_ = model.lstm(a)
				# a_lstm = a_lstm[:,-1,:]

				
				# b_lstm,_ = model.lstm(b)
				# b_lstm = b_lstm[:,-1,:]

				ementions_real_lis.append(ementions_real_lstm.cpu())
				ementions_img_lis.append(ementions_img_lstm.cpu())
		del entity_mentions_tensor,ementions_real_lstm,ementions_img_lstm
		torch.cuda.empty_cache()
		ementions_real = torch.cat(ementions_real_lis).cuda()
		ementions_img = torch.cat(ementions_img_lis).cuda()
		########################################################################
		if "olpbench" in args.data_dir:
			# test_kb = kb(os.path.join(args.data_dir,"test_data_sophis.txt"), em_map = em_map, rm_map = rm_map)
			test_kb = kb(os.path.join(args.data_dir,"test_data.txt"), em_map = em_map, rm_map = rm_map)
		else:
			test_kb = kb(os.path.join(args.data_dir,"test.txt"), em_map = em_map, rm_map = rm_map)

		print("Loading all_known pickled data...(takes times since large)")
		all_known_e2 = {}
		all_known_e1 = {}
		all_known_e2,all_known_e1 = pickle.load(open(os.path.join(args.data_dir,"all_knowns_{}_linked.pkl".format(args.train_data_type)),"rb"))


		test_e1_tokens_tensor, test_e1_tokens_lengths = convert_string_to_indices(test_kb.triples[:,0], etoken_map,maxlen=args.max_seq_length)
		test_r_tokens_tensor, test_r_tokens_lengths = convert_string_to_indices(test_kb.triples[:,1], rtoken_map,maxlen=args.max_seq_length)
		test_e2_tokens_tensor, test_e2_tokens_lengths = convert_string_to_indices(test_kb.triples[:,2], etoken_map,maxlen=args.max_seq_length)
		
		# e2_tensor = convert_string_to_indices(test_kb.triples[:,2], etoken_map)
		indices = torch.Tensor(range(len(test_kb.triples))) # indices are used to fetch alternative answers during evaluation
		test_data = TensorDataset(indices, test_e1_tokens_tensor, test_r_tokens_tensor, test_e2_tokens_tensor, test_e1_tokens_lengths, test_r_tokens_lengths, test_e2_tokens_lengths)
		test_sampler = SequentialSampler(test_data)
		test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)
		split_dim_for_eval = 1
		if(args.embedding_dim>=256 and "olpbench" in args.data_dir and "rotat" in args.model):
			split_dim_for_eval = 4
		if(args.embedding_dim>=512 and "olpbench" in args.data_dir):
			split_dim_for_eval = 4
		if(args.embedding_dim>=512 and "olpbench" in args.data_dir and "rotat" in args.model):
			split_dim_for_eval = 6
		split_dim_for_eval = 1 # NOTE: unconditional override; the heuristics above are dead code
		for index, test_e1_tokens, test_r_tokens, test_e2_tokens, test_e1_lengths, test_r_lengths, test_e2_lengths in tqdm(test_dataloader,desc="Test dataloader"):
			print(metrics)
			test_e1_tokens, test_e1_lengths = test_e1_tokens.to(device), test_e1_lengths.to(device)
			test_r_tokens, test_r_lengths = test_r_tokens.to(device), test_r_lengths.to(device)
			test_e2_tokens, test_e2_lengths = test_e2_tokens.to(device), test_e2_lengths.to(device)
			with torch.no_grad():
				e1_real_lstm, e1_img_lstm = model.get_mention_embedding(test_e1_tokens,0, test_e1_lengths)
				r_real_lstm, r_img_lstm = model.get_mention_embedding(test_r_tokens,1, test_r_lengths)	
				e2_real_lstm, e2_img_lstm = model.get_mention_embedding(test_e2_tokens,0, test_e2_lengths)


			for count in tqdm(range(index.shape[0]), desc="Evaluating"):
				# breakpoint()
				this_e1_real = e1_real_lstm[count].unsqueeze(0)
				this_e1_img  = e1_img_lstm[count].unsqueeze(0)
				this_r_real  = r_real_lstm[count].unsqueeze(0)
				this_r_img   = r_img_lstm[count].unsqueeze(0)
				this_e2_real = e2_real_lstm[count].unsqueeze(0)
				this_e2_img  = e2_img_lstm[count].unsqueeze(0)
				simi_t = model.complex_score_e1_r_with_all_ementions(this_e1_real,this_e1_img,this_r_real,this_r_img,ementions_real,ementions_img,split=split_dim_for_eval).squeeze(0)
				simi_h = model.complex_score_e2_r_with_all_ementions(this_e2_real,this_e2_img,this_r_real,this_r_img,ementions_real,ementions_img,split=split_dim_for_eval).squeeze(0)
				# get known answers for filtered ranking
				ind = index[count]
				this_correct_mentions_e2 = test_kb.e2_all_answers[int(ind.item())]
				this_correct_mentions_e1 = test_kb.e1_all_answers[int(ind.item())] 

				all_correct_mentions_e2 = all_known_e2.get((em_map[test_kb.triples[int(ind.item())][0]],rm_map[test_kb.triples[int(ind.item())][1]]),[])
				all_correct_mentions_e1 = all_known_e1.get((em_map[test_kb.triples[int(ind.item())][2]],rm_map[test_kb.triples[int(ind.item())][1]]),[])
				
				# compute metrics
				best_score = simi_t[this_correct_mentions_e2].max()
				simi_t[all_correct_mentions_e2] = -20000000 # MOST NEGATIVE VALUE
				greatereq = simi_t.ge(best_score).float()
				equal = simi_t.eq(best_score).float()
				rank = greatereq.sum()+1+equal.sum()/2.0

				metrics['mr_t'] += rank
				metrics['mrr_t'] += 1.0/rank
				metrics['hits1_t'] += rank.le(1).float()
				metrics['hits10_t'] += rank.le(10).float()
				metrics['hits50_t'] += rank.le(50).float()

				best_score = simi_h[this_correct_mentions_e1].max()
				simi_h[all_correct_mentions_e1] = -20000000 # MOST NEGATIVE VALUE
				greatereq = simi_h.ge(best_score).float()
				equal = simi_h.eq(best_score).float()
				rank = greatereq.sum()+1+equal.sum()/2.0
				metrics['mr_h'] += rank
				metrics['mrr_h'] += 1.0/rank
				metrics['hits1_h'] += rank.le(1).float()
				metrics['hits10_h'] += rank.le(10).float()
				metrics['hits50_h'] += rank.le(50).float()

				metrics['mr'] = (metrics['mr_h']+metrics['mr_t'])/2
				metrics['mrr'] = (metrics['mrr_h']+metrics['mrr_t'])/2
				metrics['hits1'] = (metrics['hits1_h']+metrics['hits1_t'])/2
				metrics['hits10'] = (metrics['hits10_h']+metrics['hits10_t'])/2
				metrics['hits50'] = (metrics['hits50_h']+metrics['hits50_t'])/2

		for key in metrics:
			metrics[key] = metrics[key] / len(test_kb.triples)
		print(metrics)