Example #1
0
    def minibatch2input(self, batch, predict=False, topk=None):
        """Convert one document's mentions into CUDA model inputs.

        Args:
            batch: list of mention dicts (one document). Reads 'context',
                'snd_ctx', 'snd_ment' and 'selected_cands' (with 'cands',
                'p_e_m', 'mask', 'true_pos', 'p_e_ent_net'); when negative
                samples are drawn, also the top-level 'cands', 'p_e_m' and
                'neg_cands'.
            predict: when True, negative samples are never drawn.
            topk: number of leading candidates kept per mention. ``None``
                means 10000; either way it is capped at the candidate count
                of the first mention.

        Returns:
            dict of CUDA ``Variable`` inputs keyed 'token_ids', 'token_mask',
            'entity_ids', 'entity_mask', 'p_e_m', 'p_e_ent_net', 'true_pos',
            's_ltoken_ids', 's_ltoken_mask', 's_rtoken_ids', 's_rtoken_mask',
            's_mtoken_ids', 's_mtoken_mask', plus the int 'n_negs' (number of
            negative samples prepended to each candidate list).
        """
        if topk is None:
            topk = 10000
        topk = min(topk, len(batch[0]['selected_cands']['cands']))

        n_ments = len(batch)
        # only use negative samples when the document doesn't have any
        # supervision (i.e. not CoNLL)
        tps = [m['selected_cands']['true_pos'] >= 0 for m in batch]
        if not predict and (self.args.multi_instance
                            or self.args.semisup) and not np.any(tps):
            n_negs = self.args.n_negs
        else:
            n_negs = 0

        # convert data items to pytorch inputs; an empty context becomes [unk]
        token_ids = [
            m['context'][0] + m['context'][1]
            if len(m['context'][0]) + len(m['context'][1]) > 0
            else [self.model.word_voca.unk_id] for m in batch
        ]
        s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
        s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
        s_mtoken_ids = [m['snd_ment'] for m in batch]

        entity_ids = torch.LongTensor(
            [m['selected_cands']['cands'][:topk] for m in batch])
        p_e_m = torch.FloatTensor(
            [m['selected_cands']['p_e_m'][:topk] for m in batch])
        entity_mask = torch.FloatTensor(
            [m['selected_cands']['mask'][:topk] for m in batch])
        # a gold position cut off by topk is marked unknown (-1)
        true_pos = torch.LongTensor([
            m['selected_cands']['true_pos']
            if m['selected_cands']['true_pos'] < topk else -1 for m in batch
        ])
        # NOTE(review): p_e_ent_net is not widened by the n_negs negative
        # columns prepended below -- confirm the model does not need matching
        # widths when negative sampling is active.
        p_e_ent_net = torch.FloatTensor([
            m['selected_cands']['p_e_ent_net'][:topk] for m in batch
        ]) if len(batch) > 1 else torch.zeros(1, entity_ids.shape[1])

        if n_negs > 0:
            # add n_negs negative samples at the beginning of lists
            def ent_neg_sample(neg_cands_p_e_m, exclusive):
                """Draw n_negs (entity_id, p_e_m) pairs not in ``exclusive``.

                Samples n_negs * 10 pool indices with replacement and drops
                entities listed in ``exclusive``. If fewer than n_negs remain,
                pads with (unk entity id, 1e-3); otherwise shuffles and keeps
                the first n_negs. Returns (ids array, p_e_m array).
                """
                sample_ids = np.random.choice(len(neg_cands_p_e_m),
                                              n_negs * 10)
                all_samples = list(
                    zip(
                        np.array([s[0] for s in neg_cands_p_e_m
                                  ])[sample_ids].astype(int),
                        np.array([s[1] for s in neg_cands_p_e_m])[sample_ids]))
                exclusive = set(exclusive)
                samples = [s for s in all_samples if s[0] not in exclusive]

                if len(samples) < n_negs:
                    samples = samples + [(self.model.entity_voca.unk_id, 1e-3)
                                         ] * (n_negs - len(samples))
                else:
                    shuffle(samples)
                    samples = samples[:n_negs]
                return np.array([s[0] for s in samples
                                 ]), np.array([s[1] for s in samples])

            # Per-mention sampling pool: every (cand, p_e_m) pair, plus each
            # negative candidate with prior 1e-3 when the mention has at most
            # topk candidates.
            neg_cands_p_e_m = []
            for m in batch:
                pool = list(zip(list(m['cands']), list(m['p_e_m'])))
                if len(m['cands']) <= topk:
                    pool += list(zip(list(m['neg_cands']),
                                     [1e-3] * len(m['neg_cands'])))
                neg_cands_p_e_m.append(pool)

            # never sample an entity that is already a kept candidate
            neg_cands_p_e_m = [
                ent_neg_sample(si, entity_ids_i) for si, entity_ids_i in zip(
                    neg_cands_p_e_m, entity_ids.numpy())
            ]
            neg_entity_ids = torch.Tensor(
                [si[0].astype(float) for si in neg_cands_p_e_m]).long()
            neg_p_e_m = torch.Tensor(
                [si[1].astype(float) for si in neg_cands_p_e_m])

            neg_entity_mask = torch.ones(n_ments, n_negs)
            entity_ids = torch.cat([neg_entity_ids, entity_ids], dim=1)
            entity_mask = torch.cat([neg_entity_mask, entity_mask], dim=1)
            p_e_m = torch.cat([neg_p_e_m, p_e_m], dim=1)
            # shift gold positions past the prepended negatives.
            # NOTE(review): this branch only runs when no mention has a
            # true_pos >= 0, so every -1 here becomes n_negs - 1 -- confirm
            # the downstream loss ignores true_pos in this setting.
            true_pos = true_pos.add_(n_negs)

        entity_ids = Variable(entity_ids.cuda())
        true_pos = Variable(true_pos.cuda())
        p_e_m = Variable(p_e_m.cuda())
        p_e_ent_net = Variable(p_e_ent_net.cuda())
        entity_mask = Variable(entity_mask.cuda())

        token_ids, token_mask = utils.make_equal_len(
            token_ids, self.model.word_voca.unk_id)
        s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
            s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
        s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
            s_rtoken_ids, self.model.snd_word_voca.unk_id)
        # reverse each right-context row (and its mask)
        s_rtoken_ids = [row[::-1] for row in s_rtoken_ids]
        s_rtoken_mask = [row[::-1] for row in s_rtoken_mask]
        s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
            s_mtoken_ids, self.model.snd_word_voca.unk_id)

        token_ids = Variable(torch.LongTensor(token_ids).cuda())
        token_mask = Variable(torch.FloatTensor(token_mask).cuda())

        s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
        s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
        s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
        s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
        s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
        s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

        return {
            'token_ids': token_ids,
            'token_mask': token_mask,
            'entity_ids': entity_ids,
            'entity_mask': entity_mask,
            'p_e_m': p_e_m,
            'p_e_ent_net': p_e_ent_net,
            'true_pos': true_pos,
            's_ltoken_ids': s_ltoken_ids,
            's_ltoken_mask': s_ltoken_mask,
            's_rtoken_ids': s_rtoken_ids,
            's_rtoken_mask': s_rtoken_mask,
            's_mtoken_ids': s_mtoken_ids,
            's_mtoken_mask': s_mtoken_mask,
            'n_negs': n_negs
        }
Example #2
0
    def predict(self, data):
        """Predict one entity per mention for every document in ``data``.

        Args:
            data: iterable of minibatches; each minibatch is the list of
                mention dicts of one document (all sharing ``'doc_name'``).

        Returns:
            dict mapping doc name -> list of ``{'pred': (entity_name, 0.)}``,
            one entry per mention in input order. A mention whose
            best-scoring candidate is masked falls back to its first
            candidate, or 'NIL' when that one is masked too.
        """
        predictions = {items[0]['doc_name']: [] for items in data}
        self.model.eval()

        for batch in data:  # each document is a minibatch
            # an empty context becomes a single unk token
            token_ids = [
                m['context'][0] + m['context'][1]
                if len(m['context'][0]) + len(m['context'][1]) > 0
                else [self.model.word_voca.unk_id]
                for m in batch
            ]
            s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
            s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
            s_mtoken_ids = [m['snd_ment'] for m in batch]

            # keep references to the id lists as built here for the
            # relation-weight printout (the names above are rebound below)
            lctx_ids = s_ltoken_ids
            rctx_ids = s_rtoken_ids
            m_ids = s_mtoken_ids

            entity_ids = Variable(
                torch.LongTensor([m['selected_cands']['cands']
                                  for m in batch]).cuda())
            p_e_m = Variable(
                torch.FloatTensor(
                    [m['selected_cands']['p_e_m'] for m in batch]).cuda())
            entity_mask = Variable(
                torch.FloatTensor([m['selected_cands']['mask']
                                   for m in batch]).cuda())
            true_pos = Variable(
                torch.LongTensor(
                    [m['selected_cands']['true_pos'] for m in batch]).cuda())

            token_ids, token_mask = utils.make_equal_len(
                token_ids, self.model.word_voca.unk_id)
            s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
            s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                s_rtoken_ids, self.model.snd_word_voca.unk_id)
            # reverse each right-context row (and its mask)
            s_rtoken_ids = [row[::-1] for row in s_rtoken_ids]
            s_rtoken_mask = [row[::-1] for row in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                s_mtoken_ids, self.model.snd_word_voca.unk_id)

            token_ids = Variable(torch.LongTensor(token_ids).cuda())
            token_mask = Variable(torch.FloatTensor(token_mask).cuda())
            # the model reads the secondary-context tensors from attributes
            # rather than from forward() arguments
            self.model.s_ltoken_ids = Variable(
                torch.LongTensor(s_ltoken_ids).cuda())
            self.model.s_ltoken_mask = Variable(
                torch.FloatTensor(s_ltoken_mask).cuda())
            self.model.s_rtoken_ids = Variable(
                torch.LongTensor(s_rtoken_ids).cuda())
            self.model.s_rtoken_mask = Variable(
                torch.FloatTensor(s_rtoken_mask).cuda())
            self.model.s_mtoken_ids = Variable(
                torch.LongTensor(s_mtoken_ids).cuda())
            self.model.s_mtoken_mask = Variable(
                torch.FloatTensor(s_mtoken_mask).cuda())

            scores = self.model.forward(token_ids,
                                        token_mask,
                                        entity_ids,
                                        entity_mask,
                                        p_e_m,
                                        gold=true_pos.view(-1, 1))
            scores = scores.cpu().data.numpy()

            # print out relation weights
            if self.args.mode == 'eval' and self.args.print_rel:
                print('================================')
                weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
                voca = self.model.snd_word_voca
                np.set_printoptions(precision=2)

                def to_words(ids):
                    return ' '.join([voca.id2word[wid] for wid in ids])

                for i in range(len(batch)):
                    print(to_words(lctx_ids[i]),
                          utils.tokgreen(to_words(m_ids[i])),
                          to_words(rctx_ids[i]))
                    for j in range(len(batch)):
                        if i == j:
                            continue
                        print('\t', weights[:, i, j], '\t',
                              to_words(lctx_ids[j]),
                              utils.tokgreen(to_words(m_ids[j])),
                              to_words(rctx_ids[j]))

            pred_ids = np.argmax(scores, axis=1)
            pred_entities = [
                m['selected_cands']['named_cands'][i]
                if m['selected_cands']['mask'][i] == 1 else
                (m['selected_cands']['named_cands'][0]
                 if m['selected_cands']['mask'][0] == 1 else 'NIL')
                for (i, m) in zip(pred_ids, batch)
            ]
            doc_names = [m['doc_name'] for m in batch]

            if self.args.mode == 'eval' and self.args.print_incorrect:
                gold = [
                    item['selected_cands']['named_cands'][
                        item['selected_cands']['true_pos']]
                    if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN'
                    for item in batch
                ]
                for i, (gold_ent, pred_ent) in enumerate(
                        zip(gold, pred_entities)):
                    if gold_ent != pred_ent:
                        print('--------------------------------------------')
                        print(gold_ent, pred_ent)
                        print(pred_ids[i], scores[i])

            for dname, entity in zip(doc_names, pred_entities):
                predictions[dname].append({'pred': (entity, 0.)})

        return predictions
Example #3
0
    def predict(self, data):
        """Predict one entity per mention, with a score and a confidence.

        Args:
            data: iterable of minibatches; each minibatch is the list of
                mention dicts of one document (all sharing ``'doc_name'``).

        Returns:
            dict mapping doc name -> list (one entry per mention, input order)
            of ``{'pred': (entity_name, 0.), 'score': s, 'confidence': c}``.
            ``s`` is the largest rescaled score of the mention and ``c`` the
            largest softmax probability over those rescaled scores (rescaling
            is described at the point where it happens). A mention whose
            best candidate is masked falls back to its first candidate, or
            'NIL' when that one is masked too.
        """
        predictions = {items[0]['doc_name']: [] for items in data}
        self.model.eval()

        for batch in data:  # each document is a minibatch
            # an empty context becomes a single unk token
            token_ids = [
                m['context'][0] + m['context'][1]
                if len(m['context'][0]) + len(m['context'][1]) > 0
                else [self.model.word_voca.unk_id]
                for m in batch
            ]
            s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
            s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
            s_mtoken_ids = [m['snd_ment'] for m in batch]

            # keep references to the id lists as built here for the
            # relation-weight printout (the names above are rebound below)
            lctx_ids = s_ltoken_ids
            rctx_ids = s_rtoken_ids
            m_ids = s_mtoken_ids

            entity_ids = Variable(
                torch.LongTensor([m['selected_cands']['cands']
                                  for m in batch]).to(device))
            p_e_m = Variable(
                torch.FloatTensor(
                    [m['selected_cands']['p_e_m'] for m in batch]).to(device))
            entity_mask = Variable(
                torch.FloatTensor([m['selected_cands']['mask']
                                   for m in batch]).to(device))
            true_pos = Variable(
                torch.LongTensor([
                    m['selected_cands']['true_pos'] for m in batch
                ]).to(device))

            token_ids, token_mask = utils.make_equal_len(
                token_ids, self.model.word_voca.unk_id)
            s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                s_ltoken_ids, self.model.snd_word_voca.unk_id, to_right=False)
            s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                s_rtoken_ids, self.model.snd_word_voca.unk_id)
            # reverse each right-context row (and its mask)
            s_rtoken_ids = [row[::-1] for row in s_rtoken_ids]
            s_rtoken_mask = [row[::-1] for row in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                s_mtoken_ids, self.model.snd_word_voca.unk_id)

            token_ids = Variable(torch.LongTensor(token_ids).to(device))
            token_mask = Variable(torch.FloatTensor(token_mask).to(device))
            # the model reads the secondary-context tensors from attributes
            # rather than from forward() arguments
            self.model.s_ltoken_ids = Variable(
                torch.LongTensor(s_ltoken_ids).to(device))
            self.model.s_ltoken_mask = Variable(
                torch.FloatTensor(s_ltoken_mask).to(device))
            self.model.s_rtoken_ids = Variable(
                torch.LongTensor(s_rtoken_ids).to(device))
            self.model.s_rtoken_mask = Variable(
                torch.FloatTensor(s_rtoken_mask).to(device))
            self.model.s_mtoken_ids = Variable(
                torch.LongTensor(s_mtoken_ids).to(device))
            self.model.s_mtoken_mask = Variable(
                torch.FloatTensor(s_mtoken_mask).to(device))

            scores = self.model.forward(token_ids,
                                        token_mask,
                                        entity_ids,
                                        entity_mask,
                                        p_e_m,
                                        gold=true_pos.view(-1, 1))
            scores = scores.cpu().data.numpy()

            # print out relation weights
            if self.args.mode == 'eval' and self.args.print_rel:
                print('================================')
                weights = self.model._rel_ctx_ctx_weights.cpu().data.numpy()
                voca = self.model.snd_word_voca
                np.set_printoptions(precision=2)

                def to_words(ids):
                    return ' '.join([voca.id2word[wid] for wid in ids])

                for i in range(len(batch)):
                    print(to_words(lctx_ids[i]),
                          utils.tokgreen(to_words(m_ids[i])),
                          to_words(rctx_ids[i]))
                    for j in range(len(batch)):
                        if i == j:
                            continue
                        print('\t', weights[:, i, j], '\t',
                              to_words(lctx_ids[j]),
                              utils.tokgreen(to_words(m_ids[j])),
                              to_words(rctx_ids[j]))

            pred_ids = np.argmax(scores, axis=1)

            # Rescale each mention's scores: candidates tied at the minimum
            # score are treated as absent (-inf); the others are centred on
            # their mean and multiplied by 10x their count. Rows whose scores
            # are all equal are kept unchanged.
            processed_scores = []
            for row in scores:
                min_score = np.min(row)
                if np.all(row == min_score):
                    processed_scores.append(row.tolist())
                    continue
                # -np.inf rather than np.NINF, which NumPy 2.0 removed
                row = np.where(row == min_score, -np.inf, row)
                kept = row[row != -np.inf]
                processed_scores.append(
                    ((row - np.mean(kept)) * len(kept) * 10).tolist())
            processed_scores = np.array(processed_scores)

            # numerically stable row-wise softmax over the rescaled scores
            e_x = np.exp(processed_scores -
                         np.amax(processed_scores, axis=1, keepdims=True))
            softmax_scores = e_x / e_x.sum(axis=1, keepdims=True)

            pred_scores = np.amax(processed_scores, axis=1)
            pred_confidences = np.amax(softmax_scores, axis=1)
            pred_entities = [
                m['selected_cands']['named_cands'][i]
                if m['selected_cands']['mask'][i] == 1 else
                (m['selected_cands']['named_cands'][0]
                 if m['selected_cands']['mask'][0] == 1 else 'NIL')
                for (i, m) in zip(pred_ids, batch)
            ]
            doc_names = [m['doc_name'] for m in batch]

            if self.args.mode == 'eval' and self.args.print_incorrect:
                gold = [
                    item['selected_cands']['named_cands'][
                        item['selected_cands']['true_pos']]
                    if item['selected_cands']['true_pos'] >= 0 else 'UNKNOWN'
                    for item in batch
                ]
                for item, gold_ent, pred_ent in zip(batch, gold,
                                                    pred_entities):
                    if gold_ent != pred_ent:
                        print('--------------------------------------------')
                        pprint(item['raw'])
                        print(gold_ent, pred_ent)

            for dname, entity, pred_score, pred_confidence in zip(
                    doc_names, pred_entities, pred_scores, pred_confidences):
                predictions[dname].append({
                    'pred': (entity, 0.),
                    'score': pred_score,
                    'confidence': pred_confidence
                })
        return predictions
Example #4
0
    def train(self, org_train_dataset, org_dev_datasets, config):
        """Train the model with Adam, evaluating on the dev sets periodically.

        Args:
            org_train_dataset: raw training data, converted with
                ``self.get_data_items(..., predict=False)``.
            org_dev_datasets: indexable sequence of ``(name, raw_data)``
                pairs; each is converted with
                ``self.get_data_items(..., predict=True)`` and scored with
                ``D.eval`` against its raw data.
            config: dict with at least ``'lr'`` and ``'n_epochs'``. NOTE:
                ``config['lr']`` is overwritten in place when the learning
                rate is lowered.

        Schedule: every ``eval_after_n_epochs`` epochs the dev sets are
        evaluated and the F1 of the set named 'aida-A' drives the schedule.
        Once the LR is 1e-4 and that F1 reaches ``args.dev_f1_change_lr``,
        the LR drops to 1e-5, evaluation runs every 2 epochs, and early
        stopping starts counting. The model is saved to ``args.model_path``
        whenever the F1 is not below the best seen while counting. Training
        stops after ``args.n_not_inc`` evaluations without improvement.
        """
        print('extracting training data')
        train_dataset = self.get_data_items(org_train_dataset, predict=False)
        print('#train docs', len(train_dataset))

        dev_datasets = []
        for dname, data in org_dev_datasets:
            dev_datasets.append((dname, self.get_data_items(data,
                                                            predict=True)))
            print(dname, '#dev docs', len(dev_datasets[-1][1]))

        print('creating optimizer')
        optimizer = optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=config['lr'])
        # early-stopping state; counting only starts after the LR drop below
        best_f1 = -1
        not_better_count = 0
        is_counting = False
        eval_after_n_epochs = self.args.eval_after_n_epochs

        for e in range(config['n_epochs']):
            # shuffles the document order in place
            shuffle(train_dataset)

            total_loss = 0
            for dc, batch in enumerate(
                    train_dataset):  # each document is a minibatch
                # evaluation calls self.predict(), which switches the model
                # to eval mode, so training mode is re-enabled every batch
                self.model.train()
                optimizer.zero_grad()

                # convert data items to pytorch inputs
                token_ids = [
                    m['context'][0] + m['context'][1]
                    if len(m['context'][0]) + len(m['context'][1]) > 0 else
                    [self.model.word_voca.unk_id] for m in batch
                ]
                s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
                s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
                s_mtoken_ids = [m['snd_ment'] for m in batch]

                entity_ids = Variable(
                    torch.LongTensor(
                        [m['selected_cands']['cands'] for m in batch]).cuda())
                true_pos = Variable(
                    torch.LongTensor([
                        m['selected_cands']['true_pos'] for m in batch
                    ]).cuda())
                p_e_m = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['p_e_m'] for m in batch]).cuda())
                entity_mask = Variable(
                    torch.FloatTensor(
                        [m['selected_cands']['mask'] for m in batch]).cuda())

                token_ids, token_mask = utils.make_equal_len(
                    token_ids, self.model.word_voca.unk_id)
                s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                    s_ltoken_ids,
                    self.model.snd_word_voca.unk_id,
                    to_right=False)
                s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                    s_rtoken_ids, self.model.snd_word_voca.unk_id)
                # reverse each right-context row (and its mask)
                s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
                s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
                s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                    s_mtoken_ids, self.model.snd_word_voca.unk_id)

                token_ids = Variable(torch.LongTensor(token_ids).cuda())
                token_mask = Variable(torch.FloatTensor(token_mask).cuda())
                # too ugly but too lazy to fix it
                # (the model reads these tensors from attributes, not from
                # forward() arguments)
                self.model.s_ltoken_ids = Variable(
                    torch.LongTensor(s_ltoken_ids).cuda())
                self.model.s_ltoken_mask = Variable(
                    torch.FloatTensor(s_ltoken_mask).cuda())
                self.model.s_rtoken_ids = Variable(
                    torch.LongTensor(s_rtoken_ids).cuda())
                self.model.s_rtoken_mask = Variable(
                    torch.FloatTensor(s_rtoken_mask).cuda())
                self.model.s_mtoken_ids = Variable(
                    torch.LongTensor(s_mtoken_ids).cuda())
                self.model.s_mtoken_mask = Variable(
                    torch.FloatTensor(s_mtoken_mask).cuda())

                scores = self.model.forward(token_ids,
                                            token_mask,
                                            entity_ids,
                                            entity_mask,
                                            p_e_m,
                                            gold=true_pos.view(-1, 1))
                loss = self.model.loss(scores, true_pos)

                loss.backward()
                optimizer.step()
                #                self.model.regularize(max_norm=100)

                loss = loss.cpu().data.numpy()
                total_loss += loss
                # in-place progress line: epoch, percent done, batch loss
                print('epoch',
                      e,
                      "%0.2f%%" % (dc / len(train_dataset) * 100),
                      loss,
                      end='\r')

            print('epoch', e, 'total loss', total_loss,
                  total_loss / len(train_dataset))

            if (e + 1) % eval_after_n_epochs == 0:
                # F1 on the 'aida-A' dev set drives the schedule; stays 0 if
                # no dev set has that name
                dev_f1 = 0
                for di, (dname, data) in enumerate(dev_datasets):
                    predictions = self.predict(data)
                    f1 = D.eval(org_dev_datasets[di][1], predictions)
                    print(dname, utils.tokgreen('micro F1: ' + str(f1)))

                    if dname == 'aida-A':
                        dev_f1 = f1

                # one-time switch: drop LR 1e-4 -> 1e-5, evaluate every 2
                # epochs and start early-stopping bookkeeping
                if config[
                        'lr'] == 1e-4 and dev_f1 >= self.args.dev_f1_change_lr:
                    eval_after_n_epochs = 2
                    is_counting = True
                    best_f1 = dev_f1
                    not_better_count = 0

                    config['lr'] = 1e-5
                    print('change learning rate to', config['lr'])
                    # 'rel-norm' builds a fresh Adam optimizer (dropping the
                    # old optimizer's state); 'ment-norm' only updates the LR
                    # of the existing one. Any other mulrel_type leaves the
                    # optimizer's LR unchanged.
                    if self.args.mulrel_type == 'rel-norm':
                        optimizer = optim.Adam([
                            p
                            for p in self.model.parameters() if p.requires_grad
                        ],
                                               lr=config['lr'])
                    elif self.args.mulrel_type == 'ment-norm':
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = config['lr']

                if is_counting:
                    if dev_f1 < best_f1:
                        not_better_count += 1
                    else:
                        # ties with the best F1 also reset the counter and
                        # save the model
                        not_better_count = 0
                        best_f1 = dev_f1
                        print('save model to', self.args.model_path)
                        self.model.save(self.args.model_path)

                # early stop after n_not_inc evaluations without improvement
                if not_better_count == self.args.n_not_inc:
                    break

                self.model.print_weight_norm()
Example #5
0
    def predict(self, data, topk=1):
        """Predict entities for every mention in ``data``.

        Args:
            data: iterable of minibatches; each minibatch is the list of
                mention dicts of one document (all sharing ``'doc_name'``).
            topk: when 1, predict a single entity per mention; for any other
                value, return each mention's full candidate list ranked by
                score. NOTE(review): the value only acts as a switch -- the
                ranked list is not truncated to ``topk`` entries.

        Returns:
            dict mapping doc name -> list (one entry per mention, input order)
            of ``{'pred': (x, 0.)}``.
            - With ``topk == 1``, ``x`` is the best-scoring candidate's name.
              If that candidate is masked, it is the first candidate's name,
              or 'NIL' when that one is masked too.
            - Otherwise ``x`` is the list of candidate names in descending
              score order. Masked candidates become 'NIL', and a trailing run
              of 'NIL's is collapsed to a single 'NIL'.
        """
        predictions = {items[0]['doc_name']: [] for items in data}
        self.model.eval()

        for batch in data:  # each document is a minibatch
            # an empty context becomes a single unk token
            token_ids = [m['context'][0] + m['context'][1]
                         if len(m['context'][0]) + len(m['context'][1]) > 0
                         else [self.model.word_voca.unk_id]
                         for m in batch]
            s_ltoken_ids = [m['snd_ctx'][0] for m in batch]
            s_rtoken_ids = [m['snd_ctx'][1] for m in batch]
            s_mtoken_ids = [m['snd_ment'] for m in batch]

            entity_ids = Variable(torch.LongTensor([m['selected_cands']['cands'] for m in batch]).cuda())
            p_e_m = Variable(torch.FloatTensor([m['selected_cands']['p_e_m'] for m in batch]).cuda())
            entity_mask = Variable(torch.FloatTensor([m['selected_cands']['mask'] for m in batch]).cuda())
            true_pos = Variable(torch.LongTensor([m['selected_cands']['true_pos'] for m in batch]).cuda())

            token_ids, token_mask = utils.make_equal_len(token_ids, self.model.word_voca.unk_id)
            s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(s_ltoken_ids, self.model.snd_word_voca.unk_id,
                                                               to_right=False)
            s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(s_rtoken_ids, self.model.snd_word_voca.unk_id)
            # reverse each right-context row (and its mask)
            s_rtoken_ids = [row[::-1] for row in s_rtoken_ids]
            s_rtoken_mask = [row[::-1] for row in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(s_mtoken_ids, self.model.snd_word_voca.unk_id)

            token_ids = Variable(torch.LongTensor(token_ids).cuda())
            token_mask = Variable(torch.FloatTensor(token_mask).cuda())
            # the model reads the secondary-context tensors from attributes
            # rather than from forward() arguments
            self.model.s_ltoken_ids = Variable(torch.LongTensor(s_ltoken_ids).cuda())
            self.model.s_ltoken_mask = Variable(torch.FloatTensor(s_ltoken_mask).cuda())
            self.model.s_rtoken_ids = Variable(torch.LongTensor(s_rtoken_ids).cuda())
            self.model.s_rtoken_mask = Variable(torch.FloatTensor(s_rtoken_mask).cuda())
            self.model.s_mtoken_ids = Variable(torch.LongTensor(s_mtoken_ids).cuda())
            self.model.s_mtoken_mask = Variable(torch.FloatTensor(s_mtoken_mask).cuda())

            scores = self.model.forward(token_ids, token_mask, entity_ids, entity_mask, p_e_m,
                                        gold=true_pos.view(-1, 1))
            scores = scores.cpu().data.numpy()
            doc_names = [m['doc_name'] for m in batch]

            if topk == 1:
                pred_ids = np.argmax(scores, axis=1)
                pred_entities = [m['selected_cands']['named_cands'][i] if m['selected_cands']['mask'][i] == 1
                                 else (m['selected_cands']['named_cands'][0] if m['selected_cands']['mask'][0] == 1
                                       else 'NIL')
                                 for (i, m) in zip(pred_ids, batch)]
                for dname, entity in zip(doc_names, pred_entities):
                    predictions[dname].append({'pred': (entity, 0.)})
            else:
                # candidate indices per mention, best score first
                ranked_ids = np.argsort(scores, axis=1)[:, ::-1]
                for dname, ids, m in zip(doc_names, ranked_ids, batch):
                    cands = m['selected_cands']
                    entities = [cands['named_cands'][i] if cands['mask'][i] == 1 else 'NIL'
                                for i in ids]
                    # collapse a trailing run of 'NIL' to a single 'NIL'
                    while len(entities) >= 2 and entities[-1] == 'NIL' and entities[-2] == 'NIL':
                        entities.pop()
                    predictions[dname].append({'pred': (entities, 0.)})

        return predictions