class ELMoVectors(object):
    def __init__(self, size_elmo, device):
        self.size_elmo = size_elmo
        self.device = device
        self.model = Elmo(options_files[size_elmo],
                          weight_files[size_elmo],
                          1,
                          dropout=0.,
                          requires_grad=False)
        self.model.to(device)

    def get_embedding_size(self):
        return elmo_emb_size[self.size_elmo]

    def transform(self, X):
        # whitespace-tokenize each text so batch_to_ids can build per-word character ids
        X = self.tokenize(X)
        word_token = batch_to_ids(X).to(self.device)
        word_emb = self.model(word_token)

        # free the character-id tensor
        del word_token

        return word_emb['elmo_representations'][0]

    def tokenize(self, X):
        for i in range(len(X)):
            X[i] = X[i].split(' ')
        return X
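
# Usage sketch (not part of the original snippet): the class assumes
# module-level dicts options_files, weight_files and elmo_emb_size keyed by a
# size name. A minimal, assumed setup with the public small ELMo files:
options_files = {"small": "https://allennlp.s3.amazonaws.com/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"}
weight_files = {"small": "https://allennlp.s3.amazonaws.com/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"}
elmo_emb_size = {"small": 256}  # small ELMo: 2 directions x 128-d projection
# vectors = ELMoVectors("small", torch.device("cpu"))
# emb = vectors.transform(["a first sentence", "and a second"])  # (2, 3, 256)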
Example #2
class LayerContextWordEmbeddings(LayerBase):
    """LayerWordEmbeddings implements word embeddings."""
    def __init__(self,
                 word_seq_indexer,
                 gpu,
                 freeze_word_embeddings=False,
                 pad_idx=0):
        super(LayerContextWordEmbeddings, self).__init__(gpu)
        self.embeddings_dim = 1024
        self.output_dim = 1024
        print("Loading ELMo weights...")
        options_file = "embeddings/newscor.lower.elmo.options.json"
        weight_file = "embeddings/newscor.lower.elmo.weights.hdf5"
        device = torch.device("cuda:" + str(gpu) if (
            torch.cuda.is_available() and gpu > -1) else "cpu")
        self.elmo = Elmo(options_file, weight_file, 2, dropout=0)
        self.elmo = self.elmo.to(device)  # Using cuda
        print("ELMo weights loaded")

    def is_cuda(self):
        # this layer has no nn.Embedding of its own; check the ELMo parameters
        return next(self.elmo.parameters()).is_cuda

    def forward(self, word_sequences):
        #print("Creating ELMo weights...", word_sequences.shape)
        character_ids = batch_to_ids(word_sequences)
        if (self.gpu > -1):
            device = torch.device("cuda:" + str(self.gpu) if (
                torch.cuda.is_available() and self.gpu > -1) else "cpu")
            character_ids = character_ids.to(device)
        embeddings = self.elmo(character_ids)
        word_embeddings_feature = embeddings['elmo_representations'][0]
        #print("ELMo weights created")
        return word_embeddings_feature
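
# Usage sketch (assumed: the newscor option/weight files exist at the paths
# hard-coded above; gpu=-1 keeps everything on CPU):
# layer = LayerContextWordEmbeddings(word_seq_indexer=None, gpu=-1)
# out = layer([["the", "cat", "sat"], ["hello"]])  # forward() via __call__
# out.shape -> torch.Size([2, 3, 1024]), zero-padded to the longest sentence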
Example #3
class ElmoClass():
    def __init__(self, device):
        self.device = device
        bioelmo_options_file = "/home/xiatingyu/KGMedNLI-master/data/dataset/bioelmo/biomed_elmo_options.json"
        bioelmo_weight_file = "/home/xiatingyu/KGMedNLI-master/data/dataset/bioelmo/biomed_elmo_weights.hdf5"
        # Compute two different representations for each token.
        # Each representation is a linear weighted combination of the
        # 3 layers in ELMo (i.e., the char-CNN and the two BiLSTM outputs).
        self.model = Elmo(bioelmo_options_file,
                          bioelmo_weight_file,
                          2,
                          dropout=0)
        self.model = self.model.to(self.device)

    def get_embeddings(self, data, max_length, embedding_dim):
        character_ids = batch_to_ids(data).to(self.device)
        elmo_output = self.model(character_ids)
        batch_size = len(data)
        embedding = torch.zeros(batch_size,
                                max_length,
                                embedding_dim,
                                dtype=torch.float).to(self.device)
        mask = torch.ones(batch_size, max_length,
                          dtype=torch.float).to(self.device)
        for idx, temp in enumerate(elmo_output['elmo_representations'][0]):
            embedding[idx][:len(data[idx])] = temp[:len(data[idx])]
            mask[idx][len(data[idx]):] = 0.0

        return embedding, mask
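
# Usage sketch: `data` is a list of token lists; max_length must cover the
# longest sentence, and embedding_dim is typically 1024 for this bioELMo model:
# model = ElmoClass(torch.device("cpu"))
# emb, mask = model.get_embeddings([["aspirin", "reduces", "fever"]], 10, 1024)
# emb: (1, 10, 1024) with zeros past position 3; mask: (1, 10), 1.0 on tokens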
Example #4
def embed_corpus_with_elmo(corpus_name="ag_news",
                           document_size=4000,
                           language_model="elmo"):
    from allennlp.modules.elmo import Elmo, batch_to_ids
    # code from https://github.com/allenai/allennlp/issues/2245
    options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
    weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

    model = Elmo(options_file, weight_file, 1, dropout=0)
    model.eval()
    model = model.to(torch.device("cuda"))
    tokens = []
    embeddings = []
    corpus = get_corpus(corpus_name, document_size)
    for doc in tqdm(corpus):
        token, ids = doc.split(), batch_to_ids([doc.split()])
        ids = ids.to(torch.device("cuda"))
        with torch.no_grad():
            hidden_states = model(ids)
        embedding = hidden_states["elmo_representations"][0][0]
        embedding = embedding.detach().cpu().numpy()
        tokens.append(token)
        embeddings.append(embedding)
    with open(f"{corpus_name}.{language_model}.pk", "wb") as f:
        pickle.dump({
            "tokens": tokens,
            "embeddings": embeddings
        },
                    f,
                    protocol=4)
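
# Reading the dump back (sketch): a protocol-4 pickle loads with a plain
# pickle.load; tokens[i] pairs with embeddings[i], a (len(tokens[i]), 1024)
# array for the 1024-d ELMo model used above:
# with open("ag_news.elmo.pk", "rb") as f:
#     data = pickle.load(f)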
Example #5
class ElmoImageTextDataset(torch.utils.data.Dataset):
    def __init__(self, posts: List[Dict[str, Any]],
                 labels_map: Dict[str, Dict[str,
                                            int]], dictionary: Dictionary):
        self.posts = list(
            map(lambda post: parse_post(post, image_retriever="pretrained"),
                posts))
        self.labels_map = labels_map
        self.dictionary = dictionary
        options_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
        self.elmo = Elmo(options_file, weight_file, 2, dropout=0)
        self.elmo = self.elmo.to(device)

        # Preprocess posts data
        for post_id, _ in enumerate(self.posts):
            # Map str label to integer
            for label in self.posts[post_id]['label'].keys():
                self.posts[post_id]['label'][label] = self.labels_map[label][
                    self.posts[post_id]['label'][label]]

            # Convert caption to list of token indices
            self.posts[post_id]['caption'] += '.'
            character_ids = batch_to_ids(
                [self.posts[post_id]['caption'].split(" ")])
            character_ids = character_ids.to(
                device)  # (len(batch), max sentence length, max word length).
            x = self.elmo(character_ids)
            self.posts[post_id]['caption'] = x['elmo_representations'][0]

    def __len__(self) -> int:
        return len(self.posts)

    def __getitem__(self, i: int) -> Dict[str, Any]:
        output = self.posts[i]
        return output
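
# Usage sketch: the dataset plugs into a standard DataLoader; `posts`,
# `labels_map`, `dictionary` and the global `device` come from surrounding code.
# dataset = ElmoImageTextDataset(posts, labels_map, dictionary)
# loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)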
Example #6
X_test = X_test.transform()
X_test = split_annotated_documents(X_test)

x_test_text, ner_test_tags, x_test_tokens = annotated_docs_to_tokens(X_test)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if not os.path.exists(
        os.path.join(paths.multitask_folder, 'text_test_elmo_split.pkl')):
    # elmo embeddings
    options_file = paths.elmo_options
    weight_file = paths.elmo_weights
    ELMO_folder = paths.elmo_folder
    elmo_dim = params.elmo_dim
    elmo = Elmo(options_file, weight_file, 2, dropout=0)
    elmo.to(device)
    with torch.no_grad():
        text_test = get_elmo_representation(x_test_text,
                                            elmo,
                                            elmo_dim=params.elmo_dim,
                                            device=device)
    with open(os.path.join(paths.multitask_folder, 'text_test_elmo_split.pkl'),
              'wb+') as f:
        pickle.dump(text_test, f)
else:
    with open(os.path.join(paths.multitask_folder, 'text_test_elmo_split.pkl'),
              'rb+') as f:
        text_test = pickle.load(f)

test(paths,
     params,
Example #7
class SLU:
    def __init__(self, config, model_dir, device=None):
        self.config = config
        self.model_dir = model_dir
        self.log_file = os.path.join(model_dir, 'log.csv')

        self.device = get_device(device)

        self.slu_cls = getattr(modules, config['model']['name'])
        self.slu = self.slu_cls(config['model'])

        self.use_elmo = config.get("use_elmo", False)
        if self.use_elmo:
            option_file = config["elmo"]["option_file"]
            weight_file = config["elmo"]["weight_file"]
            self.elmo = Elmo(option_file, weight_file, 1, dropout=0)
            self.slu.elmo_scalar_mixes = nn.ModuleList(self.elmo._scalar_mixes)

            if len(config["elmo"].get("checkpoint", "")) > 0:
                self.elmo._elmo_lstm = torch.load(
                    config["elmo"]["checkpoint"]).elmo
                for param in self.elmo._elmo_lstm.parameters():
                    param.requires_grad_(False)

            self.elmo.to(self.device)

        self.slu.to(self.device)

    def prepare_training(self, batch_size, data_engine, collate_fn):
        self.train_data_loader = DataLoader(data_engine,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=1,
                                            drop_last=True,
                                            collate_fn=collate_fn,
                                            pin_memory=True)

        self.parameters = filter(lambda p: p.requires_grad,
                                 self.slu.parameters())
        self.optimizer = build_optimizer(self.config["optimizer"],
                                         self.parameters,
                                         self.config["learning_rate"])

        with open(self.log_file, 'w') as fw:
            fw.write(
                "epoch,train_loss,train_f1,valid_loss,valid_f1,test_loss,test_f1\n"
            )

    def prepare_testing(self, batch_size, data_engine, collate_fn):
        self.test_data_loader = DataLoader(data_engine,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           num_workers=1,
                                           drop_last=False,
                                           collate_fn=collate_fn,
                                           pin_memory=True)

    def train(self,
              epochs,
              batch_size,
              data_engine,
              valid_data_engine=None,
              test_data_engine=None,
              checkpoint=True):
        collate_fn = getattr(data_engine,
                             self.config.get("collate_fn", "collate_fn_asr"))
        self.prepare_training(batch_size, data_engine, collate_fn)

        run_batch_fn = getattr(self,
                               self.config.get("run_batch_fn", "run_batch"))

        for idx in range(1, epochs + 1):
            epoch_loss = 0
            epoch_acc = 0.0
            batch_amount = 0

            pbar = tqdm(self.train_data_loader,
                        desc="Iteration",
                        ascii=True,
                        dynamic_ncols=True)

            for b_idx, batch in enumerate(pbar):
                loss, logits = run_batch_fn(batch, testing=False)
                epoch_loss += loss.item()
                batch_amount += 1
                y_true = batch[data_engine.label_idx]
                y_pred = logits.detach().cpu().max(dim=1)[1].numpy()
                epoch_acc += (y_true == y_pred).sum() / len(y_true)
                pbar.set_postfix(Loss="{:.5f}".format(epoch_loss /
                                                      batch_amount),
                                 Acc="{:.4f}".format(epoch_acc / batch_amount))

            epoch_loss /= batch_amount
            epoch_acc /= batch_amount
            print_time_info(
                "Epoch {} finished, training loss {}, acc {}".format(
                    idx, epoch_loss, epoch_acc))

            valid_loss, valid_acc, _, _ = self.test(batch_size,
                                                    valid_data_engine)
            test_loss, test_acc = -1.0, -1.0
            if test_data_engine is not None:
                test_loss, test_acc, _, _ = self.test(batch_size,
                                                      test_data_engine)
            with open(self.log_file, 'a') as fw:
                fw.write(f"{idx},{epoch_loss},{epoch_acc},"
                         f"{valid_loss},{valid_acc},{test_loss},{test_acc}\n")

            if checkpoint:
                print_time_info("Epoch {}: save model...".format(idx))
                self.save_model(self.model_dir, idx)

    def test(self, batch_size, data_engine, report=False, verbose=False):
        collate_fn = getattr(
            data_engine, self.config.get("collate_fn_test", "collate_fn_asr"))
        self.prepare_testing(batch_size, data_engine, collate_fn)

        run_batch_fn = getattr(self,
                               self.config.get("run_batch_fn", "run_batch"))

        test_probs = []
        all_y_true, all_y_pred = [], []
        test_acc = 0.0
        with torch.no_grad():
            test_loss = 0
            batch_amount = 0
            for b_idx, batch in enumerate(tqdm(self.test_data_loader)):
                loss, logits = run_batch_fn(batch, testing=True)
                test_loss += loss.item()
                batch_amount += 1
                y_true = batch[data_engine.label_idx]
                y_pred = logits.detach().cpu().max(dim=1)[1].numpy()
                test_acc += (y_true == y_pred).sum() / len(y_true)
                all_y_true += list(y_true)
                all_y_pred += list(y_pred)

            test_loss /= batch_amount
            test_acc /= batch_amount
            print_time_info("testing finished, testing loss {}, acc {}".format(
                test_loss, test_acc))

        if report:
            metrics = classification_report(
                np.array(all_y_true),
                np.array(all_y_pred),
                labels=list(range(len(data_engine.label_vocab.vocab))),
                target_names=data_engine.label_vocab.vocab,
                digits=3)
            print(metrics)

        if verbose:
            for i, (y_true, y_pred) in enumerate(zip(all_y_true, all_y_pred)):
                if y_true == y_pred:
                    continue
                label = data_engine.label_vocab.i2l(y_true)
                pred = data_engine.label_vocab.i2l(y_pred)
                print("{} [{}] [{}]".format(data_engine[i]["text"], label,
                                            pred))

        return test_loss, test_acc, all_y_true, all_y_pred

    def run_batch(self, batch, testing=False):
        if testing:
            self.slu.eval()
        else:
            self.slu.train()

        inputs, words, positions, labels = batch

        inputs = torch.from_numpy(inputs).to(self.device)
        labels = torch.from_numpy(labels).to(self.device)

        elmo_emb = None
        if self.use_elmo:
            char_ids = batch_to_ids(words).to(self.device)
            elmo_emb = self.elmo(char_ids)['elmo_representations'][0]

        logits = self.slu(inputs, positions, elmo_emb)

        loss = F.cross_entropy(logits, labels)

        if not testing:
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters, 1.0)
            self.optimizer.step()

        return loss, logits

    def run_batch_lattice(self, batch, testing=False):
        if testing:
            self.slu.eval()
        else:
            self.slu.train()

        inputs, words, positions, prevs, nexts, labels = batch

        inputs = torch.from_numpy(inputs).to(self.device)
        labels = torch.from_numpy(labels).to(self.device)

        elmo_emb = None
        if self.use_elmo:
            char_ids = batch_to_ids(words).to(self.device)
            elmo_emb = self.elmo(char_ids)['elmo_representations'][0]

        logits = self.slu(inputs, positions, prevs, nexts, elmo_emb)

        loss = F.cross_entropy(logits, labels)

        if not testing:
            start = time.time()
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters, 1.0)
            self.optimizer.step()
            # print(f"backward takes {time.time()-start}")

        return loss, logits

    def save_model(self, model_dir, epoch, name='slu.ckpt'):
        path = os.path.join(model_dir, "{}.{}".format(name, epoch))
        torch.save(self.slu, path)
        print_time_info("Save model successfully")

    def load_model(self, model_dir, epoch=None, name='slu.ckpt'):
        if epoch is None:
            paths = glob.glob(os.path.join(model_dir, "{}.*".format(name)))
            epoch = max(int(path.strip().split('.')[-1]) for path in paths)
            print_time_info("Epoch is not specified, loading the "
                            "last epoch ({}).".format(epoch))
        path = os.path.join(model_dir, "{}.{}".format(name, epoch))
        if not os.path.exists(path):
            print_time_info("Loading failed, start training from scratch...")
        else:
            self.slu.load_state_dict(
                torch.load(path, map_location=self.device).state_dict())
            if self.use_elmo and hasattr(self.slu, "elmo_scalar_mixes"):
                self.elmo._scalar_mixes = self.slu.elmo_scalar_mixes
                self.elmo.add_module('scalar_mix_0',
                                     self.elmo._scalar_mixes[0])
            print_time_info(
                "Load model from {} successfully".format(model_dir))
        return epoch
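
# Minimal config sketch covering only the keys SLU.__init__/train/test read
# above; all values are placeholders, not the project's actual settings:
# config = {
#     "model": {"name": "SomeSLUModule"},   # class looked up on `modules`
#     "use_elmo": True,
#     "elmo": {"option_file": "elmo_options.json",
#              "weight_file": "elmo_weights.hdf5",
#              "checkpoint": ""},           # optional fine-tuned ELMo dump
#     "optimizer": "adam",
#     "learning_rate": 1e-3,
#     "collate_fn": "collate_fn_asr",
#     "run_batch_fn": "run_batch",
# }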
Example #8
class CRF_FB(nn.Module):
    def __init__(self,
                 device,
                 tag_to_ix,
                 n_layers,
                 hidden_dim,
                 hidden_dim_pp,
                 char_cnn,
                 n_chars,
                 char_cnn_filters,
                 pairwise_gate,
                 train_type="sequence",
                 normalization="weight",
                 elmo_dropout_ratio=0.,
                 dropout_ratio=0.,
                 shared_lstm=False,
                 inp_config="full",
                 pairwise_query_type='mul',
                 bilinear_dim=300,
                 elmo_dim=1024,
                 attn='multi',
                 all_test=False,
                 gate_bias=-1.,
                 monitor=None,
                 logger=None):
        super(CRF_FB, self).__init__()
        self.device = device
        self.hidden_dim = hidden_dim
        self.hidden_dim_pp = hidden_dim_pp
        self.bilinear_dim = bilinear_dim
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.monitor = monitor
        self.embedding_dim = elmo_dim
        self.normalization = normalization
        self.elmo_dropout_ratio = elmo_dropout_ratio
        self.dropout_ratio = dropout_ratio
        self.train_type = train_type.lower()
        self.n_layers = n_layers
        self.char_cnn = char_cnn
        self.pairwise_gate = pairwise_gate
        self.bilinear_inp_dim = self.embedding_dim
        self.bilinear_out_dim = hidden_dim
        self.char_cnn_highway_bias = -1.
        self.query_dim = hidden_dim
        self.attn_dim = hidden_dim
        self.inp_config = inp_config
        self.shared_lstm = shared_lstm
        self.pairwise_query_type = pairwise_query_type
        self.pairwise_bilinear_pooling = True
        self.all_test = all_test
        self.logger = logger
        self.logger.info("Pairwise Type = {}".format(self.pairwise_query_type))

        if self.inp_config != "w2v":
            self.elmo = Elmo(ELMO_OPTIONS_FILE,
                             ELMO_WEIGHT_FILE,
                             1,
                             requires_grad=False,
                             dropout=self.elmo_dropout_ratio)
            self.elmo.to(self.device)

        self.act = nn.ELU()

        self.layer_norm = nn.LayerNorm(self.embedding_dim)

        if self.train_type != "no_unary":
            self.logger.info("Unary Config")
            self.lstm = nn.LSTM(self.embedding_dim,
                                self.hidden_dim,
                                num_layers=self.n_layers,
                                dropout=self.dropout_ratio,
                                bidirectional=True).to(device=device)

            self.unary_fc = weight_norm(nn.Linear(2 * hidden_dim,
                                                  2 * hidden_dim,
                                                  bias=True).to(device=device),
                                        dim=None)
            self.init_parameters(self.unary_fc, 'relu')

            self.out_dropout_u_fc = nn.Dropout(self.dropout_ratio)
            self.out_dropout_u_skip = nn.Dropout(self.dropout_ratio)

            self.hidden2tag = weight_norm(nn.Linear(
                2 * hidden_dim, self.tagset_size).to(device=device),
                                          dim=None)
            self.init_parameters(self.hidden2tag, 'linear')

            tran_init = torch.empty(self.tagset_size,
                                    self.tagset_size,
                                    dtype=torch.float,
                                    requires_grad=True)
            torch.nn.init.normal_(tran_init, mean=0.0, std=1.)
            self.transitions = nn.Parameter(tran_init.to(device=device))
            self.transitions.data[:, tag_to_ix[
                DatasetPreprosessed.__START_TAG__]] = -100.
            self.transitions.data[tag_to_ix[
                DatasetPreprosessed.__STOP_TAG__], :] = -100.

        if self.train_type != "no_pairwise":
            self.logger.info("Pairwise Config")
            if not self.shared_lstm:
                self.logger.info("Separate LSTMs")
                self.lstm_pairwise = nn.LSTM(
                    self.embedding_dim,
                    self.hidden_dim,
                    num_layers=self.n_layers,
                    dropout=self.dropout_ratio,
                    bidirectional=True).to(device=device)
            else:
                self.logger.info("Shared LSTM")
            self.U = weight_norm(nn.Linear(
                2 * self.hidden_dim, self.hidden_dim_pp).to(device=device),
                                 dim=None)
            self.init_parameters(self.U, 'relu')
            self.V = weight_norm(nn.Linear(
                2 * self.hidden_dim, self.hidden_dim_pp).to(device=device),
                                 dim=None)
            self.init_parameters(self.V, 'relu')
            self.P = weight_norm(nn.Linear(
                self.hidden_dim_pp, self.bilinear_dim).to(device=device),
                                 dim=None)
            self.init_parameters(self.P, 'relu')
            self.pairwise_fc = weight_norm(
                nn.Linear(self.bilinear_dim, self.bilinear_dim,
                          bias=True).to(device=device),
                dim=None)
            self.init_parameters(self.pairwise_fc, 'relu')
            self.dropout_p_mul = nn.Dropout(self.dropout_ratio)
            self.out_dropout_p_fc = nn.Dropout(self.dropout_ratio)
            self.out_dropout_p_skip = nn.Dropout(self.dropout_ratio)
            self.hidden2tag_pp = weight_norm(nn.Linear(
                self.bilinear_dim, self.tagset_size**2).to(device=device),
                                             dim=None)
            self.init_parameters(self.hidden2tag_pp, 'linear')

        self.__start__ = torch.tensor(
            self.tag_to_ix[DatasetPreprosessed.__START_TAG__],
            dtype=torch.long).to(device=device)
        self.__stop__ = torch.tensor(
            self.tag_to_ix[DatasetPreprosessed.__STOP_TAG__],
            dtype=torch.long).to(device=device)

    def init_parameters(self, sub_module, nonlinearity="relu"):
        nn.init.xavier_uniform_(sub_module.weight,
                                gain=nn.init.calculate_gain(nonlinearity))
        if sub_module.bias is not None:
            nn.init.constant_(sub_module.bias, 0.)

    def init_hidden(self):
        hidden = torch.zeros((2 * self.n_layers, 1, self.hidden_dim),
                             dtype=torch.float,
                             device=self.device,
                             requires_grad=True)
        cell = torch.zeros((2 * self.n_layers, 1, self.hidden_dim),
                           dtype=torch.float,
                           device=self.device,
                           requires_grad=True)
        return (hidden, cell)

    def init_pairwise_hidden(self):
        hidden = torch.zeros((2 * self.n_layers, 1, self.hidden_dim),
                             dtype=torch.float,
                             device=self.device,
                             requires_grad=True)
        cell = torch.zeros((2 * self.n_layers, 1, self.hidden_dim),
                           dtype=torch.float,
                           device=self.device,
                           requires_grad=True)
        return (hidden, cell)

    def forward_alg_unary(self, feats):
        init_alphas = torch.full((1, self.tagset_size),
                                 -100.,
                                 dtype=torch.float,
                                 requires_grad=True).to(device=self.device)
        init_alphas[0][self.__start__] = 0.

        forward_var = init_alphas

        for i, feat in enumerate(feats):
            alphas_t = []
            for next_tag in range(self.tagset_size):
                emit_score = feat[next_tag].view(1, -1).expand(
                    1, self.tagset_size)
                trans_score = self.transitions[:, next_tag].view(1, -1)
                assert emit_score.size() == trans_score.size()
                next_tag_var = forward_var + emit_score + trans_score
                alphas_t.append(utils.log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[:, self.__stop__].view(
            1, -1)
        alpha = utils.log_sum_exp(terminal_var)
        return alpha

    def forward_alg_pairwise(self, feats):
        init_alphas = torch.full((1, self.tagset_size),
                                 0,
                                 dtype=torch.float,
                                 requires_grad=True).to(device=self.device)
        forward_var = init_alphas

        for feat in feats:
            alphas_t = []
            for next_tag in range(self.tagset_size):
                trans_score = feat.view(self.tagset_size,
                                        self.tagset_size)[:, next_tag].view(
                                            1, -1)
                next_tag_var = forward_var + trans_score
                alphas_t.append(utils.log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
            terminal_var = forward_var
        alpha = utils.log_sum_exp(terminal_var)
        return alpha

    def score_sentence_unary(self, feats, tags):
        score = torch.tensor(0., dtype=torch.float,
                             requires_grad=True).to(device=self.device)
        tags = tags.squeeze().type(torch.long)
        start = self.__start__.unsqueeze(0)
        if tags.dim() == 0:
            tags = tags.unsqueeze(0)

        tags = torch.cat([start, tags])
        for i, feat in enumerate(feats):
            score = score + feat[tags[i + 1]] + self.transitions[tags[i],
                                                                 tags[i + 1]]
        score = score + self.transitions[tags[-1], self.__stop__]
        return score

    def score_sentence_pairwise(self, feats, tags):
        score = torch.tensor(0., dtype=torch.float,
                             requires_grad=True).to(device=self.device)
        tags = tags.squeeze().type(torch.long)
        start = self.__start__.unsqueeze(0)
        stop = self.__stop__.unsqueeze(0)
        if tags.dim() == 0:
            tags = tags.unsqueeze(0)

        tags = torch.cat([start, tags, stop]).to(device=self.device)
        for i, feat in enumerate(feats):
            score = score + feat.view(self.tagset_size,
                                      self.tagset_size)[tags[i], tags[i + 1]]
        return score

    def get_unary_lstm_features(self, sentence, iter):
        self.hidden = self.init_hidden()
        embeds = sentence.view(-1, 1, self.embedding_dim)

        lstm_out, self.hidden = self.lstm(embeds, self.hidden)

        return lstm_out

    def get_pairwise_lstm_features(self, sentence, iter):
        self.hidden = self.init_pairwise_hidden()
        embeds = sentence.view(-1, 1, self.embedding_dim)

        lstm_out, self.hidden = self.lstm_pairwise(embeds, self.hidden)

        return lstm_out

    def get_unary_features(self, lstm_out, iter):
        fc_inp = lstm_out[1:-1].squeeze(1)
        fc_out = self.out_dropout_u_fc(self.act(
            self.unary_fc(fc_inp))) + self.out_dropout_u_skip(fc_inp)
        feats = self.hidden2tag(fc_out)

        return feats

    def get_pairwise_features(self, lstm_out, iter):
        U_inp = lstm_out[:-1].squeeze(1)
        V_inp = lstm_out[1:].squeeze(1)

        if self.pairwise_bilinear_pooling:
            U = self.U(U_inp)
            V = self.V(V_inp)
            h = torch.mul(U, V)
            fc_inp = self.act(self.P(self.dropout_p_mul(h)))
            fc_out = self.out_dropout_p_fc(self.act(
                self.pairwise_fc(fc_inp))) + self.out_dropout_p_skip(fc_inp)
            feats = self.hidden2tag_pp(fc_out)
        else:
            U = U_inp
            V = V_inp
            fc_inp = torch.cat([U, V, torch.mul(U, V)], 1)
            fc_out = self.out_dropout_p_fc(self.act(
                self.pairwise_fc(fc_inp))) + self.out_dropout_p_skip(fc_inp)
            feats = self.hidden2tag_pp(fc_out)

        return feats

    def viterbi_decode(self, feats, feats_pp):
        backpointers = []
        init_vvars = torch.full((1, self.tagset_size),
                                -100.,
                                dtype=torch.float).to(device=self.device)
        init_vvars[0][self.__start__] = 0

        forward_var = init_vvars
        for feat, feat_pp in zip(feats, feats_pp[:-1]):
            bptrs_t = []
            viterbivars_t = []
            for next_tag in range(self.tagset_size):
                pairwise_transition_score = feat_pp.view(
                    self.tagset_size, self.tagset_size)[:,
                                                        next_tag].view(1, -1)
                next_tag_var = forward_var + self.transitions[:, next_tag].view(
                    1, -1) + pairwise_transition_score
                best_tag_id = utils.argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)
        terminal_var = forward_var + feats_pp[-1].view(
            self.tagset_size, self.tagset_size)[:, self.__stop__].view(1, -1)
        best_tag_id = utils.argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        start = best_path.pop()
        best_path.reverse()
        return path_score, best_path

    def viterbi_decode_unary(self, feats):
        backpointers = []
        init_vvars = torch.full((1, self.tagset_size),
                                -100.,
                                dtype=torch.float).to(device=self.device)
        init_vvars[0][self.__start__] = 0

        forward_var = init_vvars
        for feat in feats:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step
            for next_tag in range(self.tagset_size):
                next_tag_var = forward_var + self.transitions[:,
                                                              next_tag].view(
                                                                  1, -1)
                best_tag_id = utils.argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (feat + torch.cat(viterbivars_t)).view(1, -1)
            backpointers.append(bptrs_t)
        terminal_var = forward_var + self.transitions[:, self.__stop__].view(
            1, -1)
        best_tag_id = utils.argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        start = best_path.pop()
        best_path.reverse()
        return path_score, best_path

    def viterbi_decode_pairwise(self, feats_pp):
        backpointers = []
        init_vvars = torch.full((1, self.tagset_size),
                                -100.,
                                dtype=torch.float).to(device=self.device)
        init_vvars[0][self.__start__] = 0

        forward_var = init_vvars
        for feat_pp in feats_pp[:-1]:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step
            for next_tag in range(self.tagset_size):
                pairwise_transition_score = feat_pp.view(
                    self.tagset_size, self.tagset_size)[:,
                                                        next_tag].view(1, -1)
                next_tag_var = forward_var + pairwise_transition_score
                best_tag_id = utils.argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (torch.cat(viterbivars_t)).view(1, -1)
            backpointers.append(bptrs_t)
        terminal_var = forward_var + feats_pp[-1].view(
            self.tagset_size, self.tagset_size)[:, self.__stop__].view(1, -1)
        best_tag_id = utils.argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        start = best_path.pop()
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sequence, words, chars, tags, iter):
        if self.inp_config == "full":
            sequence = self.elmo(sequence)["elmo_representations"][0].squeeze(
                0)
            inp_feats = torch.cat((sequence, words), 2)
        elif self.inp_config == "w2v":
            inp_feats = words
        else:
            inp_feats = self.elmo(sequence)["elmo_representations"][0].squeeze(
                0)
        inp_feats = self.layer_norm(inp_feats)

        if self.train_type == "no_unary":
            pairwise_inp_feats = self.get_pairwise_lstm_features(
                inp_feats, iter)
            pairwise_feats = self.get_pairwise_features(
                pairwise_inp_feats, iter)
            forward_score_pp = self.forward_alg_pairwise(pairwise_feats)
            gold_score_pp = self.score_sentence_pairwise(
                pairwise_feats,
                tags.type(torch.IntTensor).to(device=self.device))
            loss_pairwise = forward_score_pp - gold_score_pp
            loss = loss_pairwise

            return loss, loss_pairwise, loss_pairwise

        elif self.train_type == "no_pairwise":
            unary_inp_feats = self.get_unary_lstm_features(inp_feats, iter)
            unary_feats = self.get_unary_features(unary_inp_feats, iter)
            forward_score_u = self.forward_alg_unary(unary_feats)
            gold_score_u = self.score_sentence_unary(
                unary_feats,
                tags.type(torch.IntTensor).to(device=self.device))
            loss_unary = forward_score_u - gold_score_u
            loss = loss_unary

            return loss, loss_unary, loss_unary

        else:
            unary_inp_feats = self.get_unary_lstm_features(inp_feats, iter)
            if self.shared_lstm:
                pairwise_inp_feats = unary_inp_feats
            else:
                pairwise_inp_feats = self.get_pairwise_lstm_features(
                    inp_feats, iter)

            unary_feats = self.get_unary_features(unary_inp_feats, iter)
            pairwise_feats = self.get_pairwise_features(
                pairwise_inp_feats, iter)
            forward_score_u = self.forward_alg_unary(unary_feats)
            gold_score_u = self.score_sentence_unary(
                unary_feats,
                tags.type(torch.IntTensor).to(device=self.device))
            loss_unary = forward_score_u - gold_score_u

            forward_score_pp = self.forward_alg_pairwise(pairwise_feats)
            gold_score_pp = self.score_sentence_pairwise(
                pairwise_feats,
                tags.type(torch.IntTensor).to(device=self.device))
            loss_pairwise = forward_score_pp - gold_score_pp

            loss = loss_unary + loss_pairwise

        return loss, loss_unary, loss_pairwise

    def forward(self, sequence, words, chars):
        if self.inp_config == "full":
            sequence = self.elmo(sequence)["elmo_representations"][0].squeeze(
                0)
            inp_feats = torch.cat((sequence, words), 2)
        elif self.inp_config == "w2v":
            inp_feats = words
        else:
            inp_feats = self.elmo(sequence)["elmo_representations"][0].squeeze(
                0)

        inp_feats = self.layer_norm(inp_feats)

        if self.train_type == "no_unary":
            pairwise_inp_feats = self.get_pairwise_lstm_features(
                inp_feats, iter)
            pairwise_feats = self.get_pairwise_features(
                pairwise_inp_feats, iter)
            score, tag_seq = self.viterbi_decode_pairwise(pairwise_feats)
        elif self.train_type == "no_pairwise":
            unary_inp_feats = self.get_unary_lstm_features(inp_feats, iter)
            unary_feats = self.get_unary_features(unary_inp_feats, iter)
            score, tag_seq = self.viterbi_decode_unary(unary_feats)
        else:
            unary_inp_feats = self.get_unary_lstm_features(inp_feats, iter)
            if self.shared_lstm:
                pairwise_inp_feats = unary_inp_feats
            else:
                pairwise_inp_feats = self.get_pairwise_lstm_features(
                    inp_feats, iter)

            unary_feats = self.get_unary_features(unary_inp_feats, iter)
            pairwise_feats = self.get_pairwise_features(
                pairwise_inp_feats, iter)
            score, tag_seq = self.viterbi_decode(unary_feats, pairwise_feats)
            if self.all_test:
                score_u, tag_seq_u = self.viterbi_decode_unary(unary_feats)
                score_p, tag_seq_p = self.viterbi_decode_pairwise(
                    pairwise_feats)
                return score, tag_seq, score_u, tag_seq_u, score_p, tag_seq_p
        return score, tag_seq
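
# Hedged sketch of the `utils` helpers the CRF above depends on (the project's
# own versions may differ, but the call sites need exactly this behaviour:
# an int argmax and a numerically stable log-sum-exp over a (1, tagset) row);
# assumes `torch` is imported as in the rest of the snippet:
def argmax(vec):
    # return the index of the best tag as a Python int
    _, idx = torch.max(vec, 1)
    return idx.item()

def log_sum_exp(vec):
    # log(sum(exp(vec))) without overflow: shift by the max score first
    max_score = vec[0, argmax(vec)]
    max_broadcast = max_score.view(1, -1).expand(1, vec.size(1))
    return max_score + torch.log(torch.sum(torch.exp(vec - max_broadcast)))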
Example #9
from __future__ import print_function, division
from allennlp.modules.elmo import Elmo, batch_to_ids
import torch
import time

device = "cuda" if torch.cuda.is_available() else "cpu"


options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
elmo = Elmo(options_file=options_file, weight_file=weight_file,
            do_layer_norm=False, dropout=0.0, num_output_representations=1)
elmo = elmo.to(device)

# tokens = {""}
# elmo_tokens = tokens.pop("elmo", None)
# elmo_representations = elmo(elmo_tokens)["elmo_representations"]

start_time = time.time()
sentences = [['First', 'sentence', '.'], ['Another', '.']]
character_ids = batch_to_ids(sentences).to(device)
print("character_ids", character_ids.shape, type(character_ids))

embeddings = elmo(character_ids)
elapsed_time = time.time() - start_time
print("time={}".format(elapsed_time))

print(type(embeddings), embeddings.keys())
elmo_representations = embeddings['elmo_representations']
print(len(elmo_representations))
for i in range(len(elmo_representations)):
    # the loop body is truncated in the original; printing the shape is a
    # minimal, assumed completion
    print(i, elmo_representations[i].shape)
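
# The output dict also carries a 'mask' tensor marking real tokens versus
# padding; embeddings['mask'] has shape (2, 3) for this two-sentence batch.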
Example #10
def train_node2vec(paths, params):
    dump_process_pkl = paths.dump_process
    dump_context_dict = paths.dump_context_dict
    dump_context_list = paths.dump_context_list
    dump_walks = paths.dump_walks
    save_model_path = paths.node2vec_base
    embedding_txt = paths.embedding_text
    embedding_temp = paths.embedding_temp
    embedding = paths.embedding
    mesh_graph_file = paths.MeSH_graph_disease

    if not params.randomize:
        np.random.seed(5)
        torch.manual_seed(5)
        random.seed(5)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter()

    # ----------- Random walk --------------------
    directed_graph = False

    if not os.path.exists(dump_walks):
        num_walks = 30
        walk_length = 10
        nx_G = read_graph(mesh_graph_file, directed_graph)
        G = Graph(nx_G, is_directed=directed_graph, p=params.p, q=params.q)
        G.preprocess_transition_probs()
        walks = G.simulate_walks(num_walks, walk_length)
        with open(dump_walks, 'wb') as f:
            pickle.dump(walks, f)
    else:
        with open(dump_walks, 'rb') as f:
            walks = pickle.load(f)

    if os.path.exists(dump_process_pkl):
        with open(dump_process_pkl, 'rb') as f:
            vocab = pickle.load(f)
    else:
        vocab = Vocabulary(lower=False)
        vocab.add_documents(walks)
        vocab.build()

        with open(dump_process_pkl, 'wb') as f:
            pickle.dump(vocab, f)

    # ---------- build embedding model ----------
    mesh_file = paths.MeSH_file
    ELMO_folder = paths.elmo_folder
    options_file = paths.elmo_options
    weight_file = paths.elmo_weights

    elmo = Elmo(options_file, weight_file, 2, dropout=0)
    elmo.to(device)

    mesh_graph = nx.read_gpickle(mesh_graph_file)
    mesh_graph = mesh_graph.to_undirected()

    mesh_dict = read_mesh_file(mesh_file)

    # Get the list of nodes (idx 0 is '<pad>')
    node_list = list(vocab.vocab.keys())

    # create the weight matrix in node_list order (which matches the original vocab index order)
    elmo_embedding_dim = 1024
    if not os.path.exists(os.path.join(ELMO_folder, 'elmo_weights')):
        weight_list = []
        for idx, i in enumerate(node_list):
            if i in mesh_dict:
                node_idx = vocab.token_to_id(i)
                scope_note = mesh_dict[i].scope_note
                character_ids = batch_to_ids(scope_note).to(device)
                elmo_embeddings = elmo(character_ids)
                embeddings = elmo_embeddings['elmo_representations'][0]
                mask = elmo_embeddings['mask']
                embeddings = embeddings * mask.unsqueeze(2).expand(
                    mask.shape[0], mask.shape[1], embeddings.shape[2]).float()
                embeddings = embeddings.mean(dim=0).mean(dim=0)  # average
                weight_list.append(embeddings.cpu())
            else:
                weight_list.append(torch.zeros(elmo_embedding_dim))

        with open(os.path.join(ELMO_folder, 'elmo_weights'), 'wb') as f:
            pickle.dump(weight_list, f)
    else:
        with open(os.path.join(ELMO_folder, 'elmo_weights'), 'rb') as f:
            weight_list = pickle.load(f)

    weight = torch.stack(weight_list, dim=0)

    # ---------- train SkipGram -----------------
    epochs = params.epochs
    batch_size = params.batch_size
    window = params.window
    num_neg_sample = params.num_neg_sample

    # apply the vocab transformation only once: either while building the context dict/list or during training
    if not os.path.exists(dump_context_dict):
        l, d = multiprocess(walks, window=window, transform=vocab.doc2id)
        with open(dump_context_dict, 'wb') as f:
            pickle.dump(d, f)
        with open(dump_context_list, 'wb') as f:
            pickle.dump(l, f)
    else:
        with open(dump_context_dict, 'rb') as f:
            d = pickle.load(f)
        with open(dump_context_list, 'rb') as f:
            l = pickle.load(f)

    # here the transformation is required; we sample indices directly
    sample_table = negative_sampling_table(vocab.token_counter(),
                                           transform=vocab.token_to_id)
    neg_sample = np.random.choice(sample_table, size=(len(l), num_neg_sample))

    context_data = ContextData(l, d, neg_sample, n_sample=5, transform=None)
    context_dataloader = DataLoader(context_data,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=6)

    model_embedding = SkipGramModified(len(vocab.vocab),
                                       embedding_size=elmo_embedding_dim,
                                       weight=weight)
    model_embedding.to(device)
    optimizer_FC = torch.optim.Adam(model_embedding.parameters(), lr=0.005)

    train(model_embedding,
          optimizer_FC,
          context_dataloader,
          epochs,
          device,
          neg_sample,
          n_sample=num_neg_sample,
          writer=writer,
          save_path=save_model_path,
          l=l,
          d=d,
          vocab=vocab,
          batch_size=batch_size)

    node_idx = []
    for item in node_list:
        node_idx.append(vocab.token_to_id(item))

    x = torch.tensor(node_idx, device=device)
    y = torch.zeros(x.shape, device=device)
    z = torch.zeros(x.shape, device=device)

    x, y, z = model_embedding(x, y, z)

    word_embeddings = x.cpu().detach().numpy()

    sorted_vocab_tuple = sorted(vocab.vocab.items(), key=lambda kv: kv[1])

    with open(embedding_txt, 'w') as f:
        for idx, item in enumerate(sorted_vocab_tuple):
            if item[0] == '\n':
                continue
            f.write(item[0] + ' ' +
                    ' '.join([str(i) for i in word_embeddings[idx]]) + '\n')

    glove_file = datapath(embedding_txt)
    temp_file = get_tmpfile(embedding_temp)
    _ = glove2word2vec(glove_file, temp_file)

    wv = KeyedVectors.load_word2vec_format(temp_file)
    wv.save(embedding)

    writer.close()
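
# Reloading the trained vectors later (sketch; standard gensim KeyedVectors
# API, and the MeSH ID below is only a placeholder):
# wv = KeyedVectors.load(paths.embedding)
# wv.most_similar("D003920", topn=5)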
Example #11
class WordRep(nn.Module):
    device = 'cpu'

    def __init__(self, data, tdevice='cuda:2'):
        super(WordRep, self).__init__()
        print("build word representation...")
        self.data = data
        self.device = tdevice

        # ELMo setup TODO: make this not hard-coded
        options_file = "/u/sjeblee/research/data/elmo/weights/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "/u/sjeblee/research/data/elmo/weights/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"

        self.gpu = data.HP_gpu

        self.use_char = data.use_char
        self.batch_size = data.HP_batch_size
        self.char_hidden_dim = 0
        self.char_all_feature = False
        if self.use_char:
            self.char_hidden_dim = data.HP_char_hidden_dim
            self.char_embedding_dim = data.char_emb_dim
            if data.char_feature_extractor == "CNN":
                self.char_feature = CharCNN(data.char_alphabet.size(), data.pretrain_char_embedding, self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            elif data.char_feature_extractor == "LSTM":
                self.char_feature = CharBiLSTM(data.char_alphabet.size(), data.pretrain_char_embedding, self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            elif data.char_feature_extractor == "GRU":
                self.char_feature = CharBiGRU(data.char_alphabet.size(), data.pretrain_char_embedding, self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            elif data.char_feature_extractor == "ALL":
                self.char_all_feature = True
                self.char_feature = CharCNN(data.char_alphabet.size(), data.pretrain_char_embedding, self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu, device=data.device)
                self.char_feature_extra = CharBiLSTM(data.char_alphabet.size(), data.pretrain_char_embedding, self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            else:
                print("Error char feature selection, please check parameter data.char_feature_extractor (CNN/LSTM/GRU/ALL).")
                exit(0)
        self.embedding_dim = data.word_emb_dim # Should be 1024 for Elmo

        self.drop = nn.Dropout(data.HP_dropout)

        '''
        self.word_embedding = nn.Embedding(data.word_alphabet.size(), self.embedding_dim)
        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(torch.from_numpy(self.random_embedding(data.word_alphabet.size(), self.embedding_dim)))
        '''

        # Load ELMo
        self.word_embedding = Elmo(options_file, weight_file, 1, dropout=0)

        '''
        self.feature_num = data.feature_num
        self.feature_embedding_dims = data.feature_emb_dims
        self.feature_embeddings = nn.ModuleList()
        for idx in range(self.feature_num):
            self.feature_embeddings.append(nn.Embedding(data.feature_alphabets[idx].size(), self.feature_embedding_dims[idx]))
        for idx in range(self.feature_num):
            if data.pretrain_feature_embeddings[idx] is not None:
                self.feature_embeddings[idx].weight.data.copy_(torch.from_numpy(data.pretrain_feature_embeddings[idx]))
            else:
                self.feature_embeddings[idx].weight.data.copy_(torch.from_numpy(self.random_embedding(data.feature_alphabets[idx].size(), self.feature_embedding_dims[idx])))
        '''
        if self.gpu:
            #self.drop = self.drop.to(data.device)
            print('Moving elmo module to:', self.device)
            self.word_embedding = self.word_embedding.to(self.device)
            '''
            for idx in range(self.feature_num):
                self.feature_embeddings[idx] = self.feature_embeddings[idx].to(data.device)
            '''

    def random_embedding(self, vocab_size, embedding_dim):
        pretrain_emb = np.empty([vocab_size, embedding_dim])
        scale = np.sqrt(3.0 / embedding_dim)
        for index in range(vocab_size):
            pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
        return pretrain_emb

    def forward(self, word_inputs, feature_inputs=None, word_seq_lengths=None, char_inputs=None, char_seq_lengths=None, char_seq_recover=None):
        """
            input:
                word_inputs: (batch_size, sent_len)
                features: list [(batch_size, sent_len), (batch_len, sent_len),...]
                word_seq_lengths: list of batch_size, (batch_size,1)
                char_inputs: (batch_size*sent_len, word_length)
                char_seq_lengths: list of whole batch_size for char, (batch_size*sent_len, 1)
                char_seq_recover: variable which records the char order information, used to recover char order
            output:
                Variable(batch_size, sent_len, hidden_dim)
        """
        #batch_size = word_inputs.size(0)
        #sent_len = word_inputs.size(1)
        #word_embs = self.word_embedding(word_inputs)
        batch_size = len(word_inputs)
        sent_len = max(len(w) for w in word_inputs)
        print('wordrep batch_size:', batch_size, 'sent_len:', sent_len)

        # ELMo word embeddings
        #print('word inputs:', word_inputs)
        print('self.device:', self.device)
        character_ids = batch_to_ids(word_inputs).to(self.device)
        embeddings = self.word_embedding(character_ids)['elmo_representations']
        #print('elmo embeddings:', len(embeddings))
        word_embs = embeddings[0]
        print('word embeddings:', word_embs.size())

        word_list = [word_embs]
        '''
        for idx in range(self.feature_num):
            word_list.append(self.feature_embeddings[idx](feature_inputs[idx]))
        '''

        if self.use_char:
            # Calculate char lstm last hidden
            char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lengths.cpu().numpy())
            char_features = char_features[char_seq_recover]
            char_features = char_features.view(batch_size, sent_len, -1)
            # collect char features; everything in word_list is concatenated below
            word_list.append(char_features)
            if self.char_all_feature:
                char_features_extra = self.char_feature_extra.get_last_hiddens(char_inputs, char_seq_lengths.cpu().numpy())
                char_features_extra = char_features_extra[char_seq_recover]
                char_features_extra = char_features_extra.view(batch_size, sent_len, -1)
                # Concat word and char together
                word_list.append(char_features_extra)
        word_embs = torch.cat(word_list, 2)
        word_represent = word_embs
        #word_represent = self.drop(word_embs)
        return word_represent
Example #12
    def transform(self, X, y=None):
        """ Annotates the list of `Document` objects that are provided as
            input and returns a list of `AnnotatedDocument` objects.
        """
        log.info(
            "Annotating named entities in {} documents with BiLSTM...".format(
                len(X)))

        self.model.eval()

        x_test_text, ner_test_labels, x_test_tokens = annotated_docs_to_tokens(
            X)

        elmo = Elmo(self.hparams['options_file'],
                    self.hparams['weight_file'],
                    2,
                    dropout=0)
        elmo.to(self.hparams['device'])

        att_weights = []
        text_test = []
        for idx, t in enumerate(x_test_text):
            char_id = batch_to_ids(t).to(self.hparams['device'])
            with torch.no_grad():
                elmo_emb = elmo(char_id)
            t_emb = elmo_emb['elmo_representations'][0].view(
                -1, self.hparams['elmo_dim']).detach().cpu()
            t_emb = torch.stack([
                tensor
                for tensor in t_emb if len(np.nonzero(tensor.numpy())[0]) != 0
            ],
                                dim=0)
            text_test.append(t_emb)

        y_pred = []
        with torch.no_grad():
            for batch in self.mini_batch(
                    text_test,
                    ner_test_labels,
                    p=self.ner_labels_vocab.doc2id,
                    batch_size=self.config.get_parameter(
                        'batch_size')):  # used for debugging
                x, y, sorted_idx = padding(batch)  # used for debugging
                # for x, y, sorted_index in validation_generator:
                x, y = x.to(self.hparams['device']), y.to(
                    self.hparams['device'])

                mask = torch.where(
                    y != self.ner_labels_vocab.vocab['<pad>'],
                    torch.tensor([1], dtype=torch.uint8, device=self.hparams['device']),
                    torch.tensor([0], dtype=torch.uint8, device=self.hparams['device']))

                z, _ = self.model(x, y, mask)
                pred, att_weight = self.model.decode(x, mask)

                pred_unsort = [0] * x.shape[0]
                for i, j in zip(sorted_idx, pred):
                    pred_unsort[i] = j
                y_pred.extend(pred_unsort)
                att_weights.extend(att_weight.cpu().numpy())

        lengths = map(len, x_test_tokens)
        tags = inverse_transform(np.asarray(y_pred), self.ner_labels_vocab,
                                 lengths)
        print('F1: ', f1_score(ner_test_labels, tags), '\t Acc: ',
              accuracy_score(ner_test_labels, tags))
        print(classification_report(ner_test_labels, tags))

        x_pred = transform_bio_tags_to_annotated_documents(
            x_test_tokens, tags, X)

        p, r, f1 = annotation_precision_recall_f1score(x_pred, X)

        return x_pred, [x_test_tokens, att_weights]
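
# A minimal sketch of the padding-row filter used in transform() above:
# batch_to_ids pads short sentences, so the flattened ELMo output contains
# all-zero rows; stacking only the non-zero rows keeps the real tokens.
# (Toy tensor below; the dimension and values are illustrative.)
import numpy as np
import torch

t_emb = torch.tensor([[1., 2., 3., 4.],
                      [0., 0., 0., 0.],   # a padded position
                      [5., 6., 7., 8.]])
kept = torch.stack(
    [row for row in t_emb if len(np.nonzero(row.numpy())[0]) != 0], dim=0)
print(kept.shape)  # torch.Size([2, 4])
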
    def fit(
            self,
            X,
            y=None,
            X_valid=None,
            char_emb_size=32,
            word_emb_size=128,
            char_lstm_units=32,
            word_lstm_units=128,
            pos_emb_size=16,
            dropout=0.5,
            batch_size=8,
            num_epochs=10,
            use_crf=False,
            use_char_emb=False,
            shuffle=False,
            use_pos_emb=False,
            hparams_1={},  # extra parameters merged into self.config
            hparams_2={}):  # extra parameters merged into self.hparams
        """ Trains the NER model. The input is a list of
            `AnnotatedDocument` instances.

            An example here is a token assigned a tag (the BIO scheme).
        """

        log.info("Checking parameters...")
        self.config.set_parameters({
            "num_epochs": num_epochs,
            "dropout": dropout,
            "batch_size": batch_size,
            "char_emb_size": char_emb_size,
            "word_emb_size": word_emb_size,
            "char_lstm_units": char_lstm_units,
            "word_lstm_units": word_lstm_units,
            "pos_emb_size": pos_emb_size,
            "use_crf": use_crf,
            "use_char_emb": use_char_emb,
            "shuffle": shuffle,
            "use_pos_emb": use_pos_emb
        })

        if hparams_1:
            self.config.set_parameters(hparams_1)
        self.config.validate()
        if hparams_2:
            self.hparams.update(hparams_2)

        x_train_text, ner_train_labels, x_train_tokens = annotated_docs_to_tokens(
            X)

        elmo = Elmo(self.hparams['options_file'],
                    self.hparams['weight_file'],
                    2,
                    dropout=0)
        elmo.to(self.hparams['device'])

        text_train = []
        for idx, t in enumerate(x_train_text):
            char_id = batch_to_ids(t).to(self.hparams['device'])
            with torch.no_grad():
                elmo_emb = elmo(char_id)
            # Flatten to (num_tokens, elmo_dim) and drop the all-zero rows
            # that batch_to_ids padding introduces
            t_emb = elmo_emb['elmo_representations'][0].view(
                -1, self.hparams['elmo_dim']).detach().cpu()
            t_emb = torch.stack(
                [row for row in t_emb
                 if len(np.nonzero(row.numpy())[0]) != 0],
                dim=0)
            text_train.append(t_emb)

        self.ner_labels_vocab = Vocabulary(lower=False)
        self.ner_labels_vocab.add_documents(ner_train_labels)
        self.ner_labels_vocab.build()

        print(
            "------------------- Training BiLSTM ---------------------------")
        self.timestr = time.strftime("%Y%m%d-%H%M%S")
        params = {
            'batch_size': self.config.get_parameter('batch_size'),
            'shuffle': self.config.get_parameter('shuffle'),
            'num_workers': 1
        }

        training_set = NERSequence(text_train,
                                   ner_train_labels,
                                   preprocess=self.ner_labels_vocab.doc2id)
        training_generator = DataLoader(training_set,
                                        collate_fn=padding,
                                        **params)

        if X_valid:  # validation data is effectively required: the epoch loop below evaluates on it
            x_val_text, ner_val_labels, x_val_tokens = annotated_docs_to_tokens(
                X_valid)

            text_val = []
            for idx, t in enumerate(x_val_text):
                char_id = batch_to_ids(t).to(self.hparams['device'])
                with torch.no_grad():
                    elmo_emb = elmo(char_id)
                t_emb = elmo_emb['elmo_representations'][0].view(
                    -1, self.hparams['elmo_dim']).detach().cpu()
                t_emb = torch.stack(
                    [row for row in t_emb
                     if len(np.nonzero(row.numpy())[0]) != 0],
                    dim=0)
                text_val.append(t_emb)

            params = {
                'batch_size': self.config.get_parameter('batch_size'),
                'shuffle': False,
                'num_workers': 12
            }

            validation_set = NERSequence(
                text_val,
                ner_val_labels,
                preprocess=self.ner_labels_vocab.doc2id)
            validation_generator = DataLoader(validation_set,
                                              collate_fn=padding,
                                              **params)

        self.model = BiLSTMCRF(self.config,
                               self,
                               batch_first=True,
                               device=self.hparams['device'])
        self.model.to(self.hparams['device'])

        if self.hparams['optimizer'] == 'adam':
            optimizer = optim.Adam(self.model.parameters(),
                                   lr=self.hparams['lr'])
        else:
            raise ValueError(
                'Unsupported optimizer: {}'.format(self.hparams['optimizer']))

        with open(self.hparams['output'], 'w+') as f:

            prev_f1 = 0.
            for epoch in range(self.config.get_parameter("num_epochs")):
                print("########## Epoch ", epoch, "##################")
                f.writelines(
                    f'########## Epoch {epoch} ##################\n')

                train_loss = []
                self.model.train()
                start = time.time()
                # text_train, ner_train_labels= Shuffle_lists(text_train, ner_train_labels) # Used for debugging
                # for batch in self.mini_batch(text_train, ner_train_labels, p=self.ner_labels_vocab.doc2id, batch_size=self.config.get_parameter('batch_size')): # Used for debugging
                #     x, y, _ = padding(batch) # used for debugging
                for x, y, _ in training_generator:
                    x, y = x.to(self.hparams['device']), y.to(
                        self.hparams['device'])

                    optimizer.zero_grad()

                    # 1 for real tokens, 0 for '<pad>' positions
                    mask = torch.where(
                        y != self.ner_labels_vocab.vocab['<pad>'],
                        torch.tensor([1], dtype=torch.uint8, device=self.hparams['device']),
                        torch.tensor([0], dtype=torch.uint8, device=self.hparams['device']))

                    z, _ = self.model(x, y, mask)

                    loss = z

                    loss.backward()
                    optimizer.step()
                    train_loss.append(loss.item())

                print("Epoch: ", epoch, "\tTraining Loss: ",
                      np.mean(train_loss), "\tTime: ",
                      time.time() - start)
                f.writelines(
                    f"Epoch: {epoch}\tTraining Loss: {np.mean(train_loss)}\tTime: {time.time()-start}\n"
                )

                self.model.eval()
                valid_loss = []
                y_pred = []
                start = time.time()

                with torch.no_grad():
                    # for batch in self.mini_batch(text_val,  ner_val_labels, p=self.ner_labels_vocab.doc2id, batch_size=self.config.get_parameter('batch_size')): # used for debugging
                    #     x, y, sorted_idx = padding(batch) # used for debugging
                    for x, y, sorted_idx in validation_generator:
                        x, y = x.to(self.hparams['device']), y.to(
                            self.hparams['device'])

                        # 1 for real tokens, 0 for '<pad>' positions
                        mask = torch.where(
                            y != self.ner_labels_vocab.vocab['<pad>'],
                            torch.tensor([1], dtype=torch.uint8, device=self.hparams['device']),
                            torch.tensor([0], dtype=torch.uint8, device=self.hparams['device']))

                        z, _ = self.model(x, y, mask)
                        pred, _ = self.model.decode(x, mask)

                        loss = z

                        valid_loss.append(loss.item())

                        pred_unsort = [0] * x.shape[0]
                        for i, j in zip(sorted_idx, pred):
                            pred_unsort[i] = j
                        y_pred.extend(pred_unsort)

                print("Epoch: ", epoch, "\tValidation Loss: ",
                      np.mean(valid_loss), "\tTime: ",
                      time.time() - start)
                f.writelines(
                    f"Epoch: {epoch}\tValidation Loss: {np.mean(valid_loss)} \tTime: {time.time()-start}\n"
                )

                lengths = map(len, x_val_tokens)
                tags = inverse_transform(np.asarray(y_pred),
                                         self.ner_labels_vocab, lengths)
                print('F1: ', f1_score(ner_val_labels, tags), '\t Acc: ',
                      accuracy_score(ner_val_labels, tags))
                f.writelines(
                    f'F1: {f1_score(ner_val_labels, tags)}\t Acc: {accuracy_score(ner_val_labels, tags)}\n'
                )
                print(classification_report(ner_val_labels, tags))
                f.writelines(classification_report(ner_val_labels, tags))

                x_pred = transform_bio_tags_to_annotated_documents(
                    x_val_tokens, tags, X_valid)

                p, r, f1 = annotation_precision_recall_f1score(x_pred, X_valid)
                print("Disease:\tPrecision: ", p, "\tRecall: ", r,
                      "\tF-score: ", f1)
                f.writelines(
                    f"Disease:\tPrecision: {p}\tRecall: {r}\tF-score: {f1}\n")

                if self.hparams['save_best'] and f1 > prev_f1:
                    self.save(self.hparams['file_path'])
                    prev_f1 = f1
                    print('New best: ', prev_f1, '\tSaving model...')
                    f.writelines(f'New best: {prev_f1}\tSaving model...\n')

                print("Best so far: ", prev_f1)
                f.writelines(f"Best so far: {prev_f1}\n")

        return self
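
# A minimal sketch of the pad-mask construction used in fit()/transform(),
# assuming pad index 0 (the real index comes from ner_labels_vocab):
import torch

pad_idx = 0
y = torch.tensor([[3, 5, 0, 0],
                  [2, 0, 0, 0]])
mask = torch.where(y != pad_idx,
                   torch.tensor([1], dtype=torch.uint8),
                   torch.tensor([0], dtype=torch.uint8))
# equivalent one-liner: mask = (y != pad_idx).to(torch.uint8)
print(mask)
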
def main(paths, params):
    path_to_train_input = paths.training
    path_to_valid_input = paths.develop
    path_to_test = paths.test
    ctd_file = paths.ctd_file
    c2m_file = paths.c2m_file
    toD_mesh = Convert2D(ctd_file, c2m_file)

    sentence_pad = False  # don't pad sentences with begin/end markers '<s>' and '</s>'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter()

    X = BratInput(path_to_train_input)
    X = X.transform()
    X = split_annotated_documents(X)

    X_valid = BratInput(path_to_valid_input)
    X_valid = X_valid.transform()
    X_valid = split_annotated_documents(X_valid)

    X_test = BratInput(path_to_test)
    X_test = X_test.transform()
    X_test = split_annotated_documents(X_test)

    if params.randomize:
        # Fix all random seeds for reproducibility
        torch.manual_seed(5)
        random.seed(5)
        np.random.seed(5)

    # Obtain MeSH information
    mesh_file = paths.MeSH_file
    disease_file = paths.disease_file
    mesh_graph_file = paths.MeSH_graph_disease
    mesh_folder = paths.MeSH_folder
    mt_folder = paths.multitask_folder


    # read disease file
    with open(disease_file, 'r') as f:
        disease_data = f.readlines()

    mesh_dict = read_mesh_file(mesh_file)

    mesh_graph = nx.read_gpickle(mesh_graph_file)
    mesh_graph = mesh_graph.to_undirected()
    scope_text, id2idx_dict, idx2id_dict = mesh_dict_to_tokens(mesh_dict, disease_data)
    node_list = list(idx2id_dict.values())

    # A_HAT matrix for GCN
    if not os.path.exists(os.path.join(mesh_folder, 'a_hat_matrix')):
        a_matrix = get_adjacancy_matrix(mesh_graph, node_list)

        a_matrix = sparse.coo_matrix(a_matrix)
        with open(os.path.join(mesh_folder, 'a_hat_matrix'), 'wb') as f:
            pickle.dump(a_matrix, f)
    else:
        with open(os.path.join(mesh_folder, 'a_hat_matrix'), 'rb') as f:
            a_matrix = pickle.load(f)

    i = torch.tensor([a_matrix.row, a_matrix.col], dtype=torch.long, device=device)
    v = torch.tensor(a_matrix.data, dtype=torch.float32, device=device)
    a_hat = torch.sparse.FloatTensor(i, v, torch.Size([len(node_list), len(node_list)])).to(device)

    # Construct usable data format
    x_tr_text, ner_tr_tags, x_tr_tokens = annotated_docs_to_tokens(X, sentence_pad=sentence_pad)
    x_val_text, ner_val_tags, x_val_tokens = annotated_docs_to_tokens(X_valid, sentence_pad=sentence_pad)
    x_test_text, ner_test_tags, x_test_tokens = annotated_docs_to_tokens(X_test, sentence_pad=sentence_pad)

    # elmo embeddings
    options_file = paths.elmo_options
    weight_file = paths.elmo_weights
    ELMO_folder = paths.elmo_folder
    elmo_dim = params.elmo_dim
    elmo = Elmo(options_file, weight_file, 2, dropout=0)
    elmo.to(device)

    with torch.no_grad():
        if not os.path.exists(os.path.join(mt_folder,'text_tr_elmo_split.pkl')):
            text_tr = get_elmo_representation(x_tr_text, elmo, elmo_dim=params.elmo_dim, device=device)
            with open(os.path.join(mt_folder,'text_tr_elmo_split.pkl'),'wb+') as f:
                pickle.dump(text_tr, f)
        else:
            with open(os.path.join(mt_folder,'text_tr_elmo_split.pkl'),'rb+') as f:
                text_tr = pickle.load(f)
        
        if not os.path.exists(os.path.join(mt_folder,'text_val_elmo_split.pkl')):
            text_val = get_elmo_representation(x_val_text, elmo, elmo_dim=params.elmo_dim, device=device)
            with open(os.path.join(mt_folder,'text_val_elmo_split.pkl'),'wb+') as f:
                pickle.dump(text_val, f)
        else:
            with open(os.path.join(mt_folder,'text_val_elmo_split.pkl'),'rb+') as f:
                text_val = pickle.load(f)

        if not os.path.exists(os.path.join(paths.multitask_folder,'text_test_elmo_split.pkl')):
            text_test = get_elmo_representation(x_test_text, elmo, elmo_dim=params.elmo_dim, device=device)
            with open(os.path.join(paths.multitask_folder,'text_test_elmo_split.pkl'),'wb+') as f:
                pickle.dump(text_test, f)
        else:
            with open(os.path.join(paths.multitask_folder,'text_test_elmo_split.pkl'),'rb+') as f:
                text_test = pickle.load(f)

    # NER label vocab
    ner_labels_vocab = Vocabulary(lower=False)
    ner_labels_vocab.add_documents(ner_tr_tags)
    ner_labels_vocab.build()

    # mesh scope embedding
    if not os.path.exists(os.path.join(paths.dump_folder, 'scope_emb.pkl')):
        scope_embedding, _ = get_scope_elmo(elmo, ELMO_folder, scope_text, elmo_dim, idx2id_dict, id2idx_dict, device=device)
        with open(os.path.join(paths.dump_folder, 'scope_emb.pkl'), 'wb') as f:
            pickle.dump(scope_embedding, f)
    else:
        with open(os.path.join(paths.dump_folder, 'scope_emb.pkl'), 'rb') as f:
            scope_embedding = pickle.load(f)
            
    train_el_set = EL_set(X, toD_mesh, id2idx_dict)
    val_el_set = EL_set(X_valid, toD_mesh, id2idx_dict)


    train(paths, params, X, text_tr, ner_tr_tags, train_el_set, X_valid, x_val_tokens, text_val,
            ner_val_tags, val_el_set, ner_labels_vocab, scope_text, scope_embedding, a_hat, mesh_graph, id2idx_dict, idx2id_dict, writer, device=device)
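
# A minimal, self-contained sketch of the sparse COO -> torch sparse tensor
# conversion used for the A_HAT matrix in main() (toy 3-node graph; names and
# values are illustrative, and torch.sparse_coo_tensor is the modern spelling
# of torch.sparse.FloatTensor):
import numpy as np
import torch
from scipy import sparse

a_matrix = sparse.coo_matrix(np.array([[0, 1, 0],
                                       [1, 0, 1],
                                       [0, 1, 0]], dtype=np.float32))
i = torch.tensor([a_matrix.row, a_matrix.col], dtype=torch.long)
v = torch.tensor(a_matrix.data, dtype=torch.float32)
a_hat = torch.sparse_coo_tensor(i, v, (3, 3))
print(a_hat.to_dense())
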
def train(paths,
          params,
          X,
          mesh_dict,
          scope_text,
          id2idx_dict,
          idx2id_dict,
          predictions_tr,
          annotated_docs_tr,
          X_valid,
          predictions_v,
          annotated_docs_v,
          writer=None,
          device=torch.device('cpu')):
    options_file = paths.elmo_options
    weight_file = paths.elmo_weights
    ELMO_folder = paths.elmo_folder
    elmo_dim = params.elmo_dim
    elmo = Elmo(options_file, weight_file, 2, dropout=0)
    elmo.to(device)

    # re-encode nodes
    scope_elmo_emb, _ = get_scope_elmo(elmo,
                                       ELMO_folder,
                                       scope_text,
                                       elmo_dim,
                                       idx2id_dict,
                                       id2idx_dict,
                                       device=device)
    scope_elmo_emb = re_encode(paths.embedding,
                               scope_elmo_emb,
                               idx2id_dict,
                               device=device)

    # format trainable data
    x_data, texts = construct_data(X,
                                   annotated_docs_tr,
                                   predictions_tr,
                                   scope_text,
                                   id2idx_dict,
                                   paths.ctd_file,
                                   paths.c2m_file,
                                   use_ELMO=params.use_elmo,
                                   elmo_model=elmo,
                                   elmo_dim=params.elmo_dim,
                                   device=device)

    x_v_data, _ = construct_data(X_valid,
                                 annotated_docs_v,
                                 predictions_v,
                                 scope_text,
                                 id2idx_dict,
                                 paths.ctd_file,
                                 paths.c2m_file,
                                 use_ELMO=params.use_elmo,
                                 elmo_model=elmo,
                                 elmo_dim=params.elmo_dim,
                                 device=device)

    word_vocab, char_vocab = get_text_vocab([texts])
    char_dict = get_char_dict(mesh_dict, char_vocab)

    params_dict1 = {
        'batch_size': params.batch_size,
        'shuffle': True,
        'num_workers': 12
    }

    EL_dataset = EL_sequence(x_data, scope_elmo_emb, id2idx_dict,
                             copy.deepcopy(char_vocab))
    training_generator = DataLoader(EL_dataset,
                                    collate_fn=padding_EL,
                                    **params_dict1)

    model = EL_model(elmo_dim, elmo_dim, scope_elmo_emb)
    model.to(device)

    # sim_model = EL_similarity(200, len(char_vocab))
    # sim_model.to(device)

    optimizer1 = torch.optim.Adam(model.parameters(),
                                  lr=params.lr,
                                  weight_decay=0.0)
    # optimizer2 = torch.optim.Adam(sim_model.parameters(), lr=params.lr, weight_decay=0.0)

    epochs = params.num_epochs
    batch_size = params.batch_size
    train_acc = 0
    val_acc = 0
    best_model1 = None
    best_model2 = None
    ep = 0
    with open(params.output, 'w') as f:
        for epoch in range(epochs):
            np.random.shuffle(x_data)
            x_data_ = x_data

            training_loss1 = 0
            # training_loss2 = 0
            count = 0
            start_time = time()
            model.train()
            # sim_model.train()

            tr_labels, tr_mentions, pred_index = [], [], []
            for t, x, _, y, mask in minibatch(x_data_,
                                              scope_elmo_emb,
                                              char_vocab=char_vocab,
                                              id2idx=id2idx_dict,
                                              use_elmo=True,
                                              elmo_model=elmo,
                                              n_samples=0,
                                              batch_size=batch_size,
                                              device=device):
                # for t, x, _, y, mask in training_generator:
                x, y, mask = x.to(device), y.to(device), mask.to(device)
                optimizer1.zero_grad()

                mask_ = mask.unsqueeze(2).expand(-1, -1, x.shape[2])
                x = mask_ * x
                x = torch.mean(x, dim=1)
                x = model(x)

                loss = nn.functional.cross_entropy(x, y)

                loss.backward()
                optimizer1.step()
                training_loss1 += loss.item()
                count += 1

                z = nn.functional.softmax(x, dim=1)
                _, index_sorted = torch.sort(z, descending=True)

                tr_labels.extend(y)
                pred_index.extend(index_sorted)
                tr_mentions.extend(t)

            writer.add_scalars('training', {'loss1': training_loss1 / count},
                               global_step=epoch)
            print(
                f'Epoch: {epoch}\tLoss1: {training_loss1/count}\tTime: {time()-start_time}'
            )
            f.writelines(
                f'Epoch: {epoch}\tLoss1: {training_loss1/count}\tTime: {time()-start_time}\n'
            )

            start_time = time()
            with torch.no_grad():
                model.eval()
                pred_label, label = [], []
                for t, x, y, mask in minibatch_val(x_data,
                                                   batch_size=batch_size):
                    x, mask = x.to(device), mask.to(device).float()

                    mask = mask.unsqueeze(2).expand(-1, -1, x.shape[2])
                    x = mask * x
                    x = torch.mean(x, dim=1)
                    x = model(x)

                    x = nn.functional.softmax(x, dim=1)
                    _, max_idx = torch.max(x, dim=1)
                    label.extend(y)
                    pred_label.extend(idx2id_dict[i.item()] for i in max_idx)

                # print(classification_report(label, pred_label))
                acc = accuracy_score(label, pred_label)
                print(f'Train Acc: {acc}\tTime: {time()-start_time}')
                # f.writelines(classification_report(label, pred_label))
                f.writelines(
                    f'\nTrain Acc: {acc}\tTime: {time()-start_time}\n')

                if acc > train_acc:
                    train_acc = acc

                print('Best train acc: ', train_acc)
                f.writelines(f'Best train acc: {train_acc}\n')

                pred_label, label = [], []
                for t, x, y, mask in minibatch_val(x_v_data,
                                                   batch_size=batch_size):
                    x, mask = x.to(device), mask.to(device).float()

                    mask = mask.unsqueeze(2).expand(-1, -1, x.shape[2])
                    x = mask * x
                    x = torch.mean(x, dim=1)
                    x = model(x)

                    x = nn.functional.softmax(x, dim=1)
                    _, max_idx = torch.max(x, dim=1)
                    label.extend(y)
                    pred_label.extend(idx2id_dict[i.item()] for i in max_idx)

                # print(classification_report(label, pred_label))
                acc = accuracy_score(label, pred_label)
                print(f'Valid Acc: {acc}\tTime: {time()-start_time}')
                # f.writelines(print(classification_report(label, pred_label)))
                f.writelines(
                    f'\nValid Acc: {acc}\tTime: {time()-start_time}, epoch: {epoch}\n'
                )

                if acc > val_acc:
                    val_acc = acc
                    print('Updating best model for acc: ', acc)
                    f.writelines(f'Updating model at epoch: {epoch}\n')
                    # deepcopy so later epochs don't mutate the saved best model
                    best_model1 = copy.deepcopy(model)
                    # best_model2 = copy.deepcopy(sim_model)
                    ep = epoch

                print(f'Best val acc: {val_acc}, epoch: {ep}')
                f.writelines(f'Best val acc: {val_acc}, epoch: {ep}\n')

    save(paths, params, scope_text, scope_elmo_emb, id2idx_dict, idx2id_dict,
         char_vocab, char_dict, best_model1, best_model2)
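
# A minimal sketch of the mask-and-mean-pool step used in the train()/test()
# loops above (toy shapes). Note that torch.mean divides by the padded length,
# so zeroed pad positions still dilute the average; a strict masked mean would
# divide by mask.sum(dim=1) instead.
import torch

x = torch.randn(2, 4, 8)                      # (batch, seq_len, elmo_dim)
mask = torch.tensor([[1., 1., 0., 0.],
                     [1., 1., 1., 1.]])
m = mask.unsqueeze(2).expand(-1, -1, x.shape[2])
pooled = torch.mean(m * x, dim=1)             # what the code above computes
strict = (m * x).sum(dim=1) / mask.sum(dim=1, keepdim=True)
print(pooled.shape, strict.shape)             # both torch.Size([2, 8])
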
def test(paths, params, X_test, annotated_docs_test, predictions_test):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    [
        scope_text, scope_elmo_emb, id2idx_dict, idx2id_dict, char_vocab,
        char_dict, old_params, model1, model2
    ] = load(paths)
    options_file = paths.elmo_options
    weight_file = paths.elmo_weights
    elmo = Elmo(options_file, weight_file, 2, dropout=0)
    elmo.to(device)

    unique_id = get_unique_id(paths.file1, paths.file2, paths.ctd_file,
                              paths.c2m_file)
    x_test_data, _ = construct_data(X_test,
                                    annotated_docs_test,
                                    predictions_test,
                                    scope_text,
                                    id2idx_dict,
                                    paths.ctd_file,
                                    paths.c2m_file,
                                    use_ELMO=params.use_elmo,
                                    elmo_model=elmo,
                                    elmo_dim=params.elmo_dim,
                                    device=device)

    model1.eval()
    # model2.eval()
    with torch.no_grad():
        pred_index, pred_label, labels = [], [], []
        rank, reciprocal_rank, t_mentions = [], [], []
        for t, x, y, mask in minibatch_val(x_test_data,
                                           char_vocab=char_vocab,
                                           batch_size=params.batch_size):
            x, mask = x.to(device), mask.to(device).float()

            mask = mask.unsqueeze(2).expand(-1, -1, x.shape[2])
            x = mask * x
            x = torch.mean(x, dim=1)
            x = model1(x)

            x = nn.functional.softmax(x, dim=1)
            _, max_idx = torch.max(x, dim=1)
            labels.extend(y)
            pred_label.extend(idx2id_dict[i.item()] for i in max_idx)

            _, index_sorted = torch.sort(x, descending=True)
            pred_index.extend(index_sorted)
            t_mentions.extend(t)

        sorted_list = []
        zeroshot_rank, zeroshot_rrank = [], []
        # For each cutoff k, record the gold ID when it appears in the top-k
        # predictions, otherwise the top-1 prediction (used for Acc@k below)
        top_k = (2, 5, 10, 15, 30)
        pred_at_k = {k: [] for k in top_k}
        for idx, item in enumerate(pred_index):
            id_sorted = [idx2id_dict[i.item()] for i in item]
            sorted_list.append(id_sorted)

            if labels[idx] in id_sorted:
                r = id_sorted.index(labels[idx]) + 1
                rank.append(r)
                reciprocal_rank.append(1 / r)

                if labels[idx] in unique_id:
                    zeroshot_rank.append(r)
                    zeroshot_rrank.append(1 / r)
            else:
                print(f"ID {labels[idx]} not found")

            for k in top_k:
                if labels[idx] in id_sorted[:k]:
                    pred_at_k[k].append(labels[idx])
                else:
                    pred_at_k[k].append(id_sorted[0])

        print(classification_report(labels, pred_label))
        print(f'Mean Reciprocal Rank: {np.mean(reciprocal_rank)}')
        print(f'Test Acc@1: {accuracy_score(labels, pred_label)}')
        for k in top_k:
            print(f'Test Acc@{k}: {accuracy_score(labels, pred_at_k[k])}')

        print(f'Zero shot MRR: {np.mean(zeroshot_rrank)}')
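
# A minimal sketch of the rank metrics reported above (toy data, illustrative
# IDs): MRR averages 1/rank of the gold ID in the sorted predictions, and
# Acc@k counts a prediction as correct when the gold ID appears in the top k.
import numpy as np

gold = ['D001', 'D002']
ranked = [['D003', 'D001', 'D004'],
          ['D002', 'D001', 'D003']]
ranks = [r.index(g) + 1 for g, r in zip(gold, ranked)]
mrr = np.mean([1.0 / rk for rk in ranks])
acc_at_2 = np.mean([rk <= 2 for rk in ranks])
print(mrr, acc_at_2)  # 0.75 1.0
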