Example No. 1
    def test_elmo_bilm_can_handle_higher_dimensional_input_with_cache(self):
        sentences = [[u"This", u"is", u"a", u"sentence"],
                     [u"Here", u"'s", u"one"], [u"Another", u"one"]]
        vocab, tensor = self.get_vocab_and_both_elmo_indexed_ids(sentences)
        words_to_cache = list(
            vocab.get_token_to_index_vocabulary(u"tokens").keys())
        elmo_bilm = Elmo(self.options_file,
                         self.weight_file,
                         1,
                         vocab_to_cache=words_to_cache)
        elmo_bilm.eval()

        individual_dim = elmo_bilm(tensor[u"character_ids"], tensor[u"tokens"])
        elmo_bilm = Elmo(self.options_file,
                         self.weight_file,
                         1,
                         vocab_to_cache=words_to_cache)
        elmo_bilm.eval()

        expanded_word_ids = torch.stack([tensor[u"tokens"] for _ in range(4)],
                                        dim=1)
        expanded_char_ids = torch.stack(
            [tensor[u"character_ids"] for _ in range(4)], dim=1)
        expanded_result = elmo_bilm(expanded_char_ids, expanded_word_ids)
        split_result = [
            x.squeeze(1) for x in torch.split(
                expanded_result[u"elmo_representations"][0], 1, dim=1)
        ]
        for expanded in split_result:
            numpy.testing.assert_array_almost_equal(
                expanded.data.cpu().numpy(),
                individual_dim[u"elmo_representations"][0].data.cpu().numpy())
Example No. 2
def embed_corpus_with_elmo(corpus_name="ag_news",
                           document_size=4000,
                           language_model="elmo"):
    from allennlp.modules.elmo import Elmo, batch_to_ids
    # code from https://github.com/allenai/allennlp/issues/2245
    options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
    weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

    model = Elmo(options_file, weight_file, 1, dropout=0)
    model.eval()
    model = model.to(torch.device("cuda"))
    tokens = []
    embeddings = []
    corpus = get_corpus(corpus_name, document_size)
    for doc in tqdm(corpus):
        token = doc.split()
        ids = batch_to_ids([token]).to(torch.device("cuda"))
        with torch.no_grad():
            hidden_states = model(ids)
        embedding = hidden_states["elmo_representations"][0][0]
        embedding = embedding.detach().cpu().numpy()
        tokens.append(token)
        embeddings.append(embedding)
    with open(f"{corpus_name}.{language_model}.pk", "wb") as f:
        pickle.dump({
            "tokens": tokens,
            "embeddings": embeddings
        },
                    f,
                    protocol=4)
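
A minimal sketch of reading the cached embeddings back in, assuming the function above was run with its default arguments (so the output file is "ag_news.elmo.pk"); each document should have one 1024-dimensional ELMo vector per token:

import pickle

with open("ag_news.elmo.pk", "rb") as f:
    cached = pickle.load(f)

for tokens, embedding in zip(cached["tokens"], cached["embeddings"]):
    assert embedding.shape == (len(tokens), 1024)  # one ELMo vector per token
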
Example No. 3
    def elmo_encode(self, data, __id2word):
        """
        Map word ids back to tokens with __id2word, then convert them to ELMo
        character ids (requires `from allennlp.modules.elmo import Elmo, batch_to_ids`).
        `batch_to_ids` only pads to the longest sentence in the batch, which can be
        shorter than `self.sent_pad_len`, so the output is zero-padded further below.
        """
        data_text = [self.glove_tokenizer(x, __id2word) for x in data]

        with torch.no_grad():
            elmo = Elmo(options_file, weight_file, 2, dropout=0).cuda()
            elmo.eval()
            character_ids = batch_to_ids(data_text).cuda()

            row_num = character_ids.shape[0]
            elmo_dim = self.elmo_dim

            if torch.sum(character_ids) != 0:
                elmo_emb = elmo(character_ids)['elmo_representations']
                elmo_emb = (elmo_emb[0] + elmo_emb[1]) / 2  # avg of two layers
            else:
                elmo_emb = torch.zeros([row_num, self.sent_pad_len, elmo_dim],
                                       dtype=torch.float)

        sent_len = elmo_emb.shape[1]

        if sent_len < self.sent_pad_len:
            fill_sent_len = self.sent_pad_len - sent_len
            # create a bunch of 0's to fill it up
            filler = torch.zeros([row_num, fill_sent_len, elmo_dim],
                                 dtype=torch.float)
            elmo_emb = torch.cat((elmo_emb, filler.cuda()), dim=1)
        return elmo_emb.cuda()
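
A minimal sketch of the padding issue the docstring describes, assuming `options_file` and `weight_file` point to the same ELMo files used in the snippet above: `batch_to_ids` pads only to the longest sentence in the batch, so the ELMo output may still need zero padding up to a fixed `sent_pad_len`:

import torch
from allennlp.modules.elmo import Elmo, batch_to_ids

sent_pad_len = 10
elmo = Elmo(options_file, weight_file, 1, dropout=0)
elmo.eval()

with torch.no_grad():
    batch = batch_to_ids([["a", "short", "sentence"], ["another", "one"]])
    emb = elmo(batch)["elmo_representations"][0]   # (2, 3, elmo_dim)

if emb.shape[1] < sent_pad_len:
    filler = torch.zeros(emb.shape[0], sent_pad_len - emb.shape[1], emb.shape[2])
    emb = torch.cat((emb, filler), dim=1)          # pad the time dimension with zeros
assert emb.shape[1] == sent_pad_len
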
Example No. 4
class getElmo(nn.Module):
    def __init__(self, layer=2, dropout=0, out_dim=100, gpu=True):
        super(getElmo, self).__init__()
        options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
        self.dropout = dropout
        self.gpu = gpu
        self.Elmo = Elmo(options_file, weight_file, layer, dropout=dropout)
        self.Elmo.eval()
        self.layers2one = nn.Linear(
            layer, 1).cuda() if self.gpu else nn.Linear(layer, 1)
        self.optLinear = nn.Linear(
            1024, out_dim).cuda() if self.gpu else nn.Linear(1024, out_dim)

    def forward(self, texts):
        word_idxs = batch_to_ids(texts).cuda() if self.gpu else batch_to_ids(
            texts)
        elmo_embs = self.Elmo.forward(word_idxs)
        elmo_reps = torch.stack(elmo_embs['elmo_representations'],
                                dim=-1).cuda() if self.gpu else torch.stack(
                                    elmo_embs['elmo_representations'], dim=-1)
        elmo_decrease_layer = self.layers2one(elmo_reps).squeeze()
        elmo_fit_hidden = self.optLinear(elmo_decrease_layer)
        mask = elmo_embs['mask']

        return elmo_fit_hidden, mask
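
A minimal CPU-only usage sketch of the module above (assuming torch, torch.nn and allennlp's Elmo/batch_to_ids are imported as the class requires); the two ELMo output representations are mixed by the learned `layers2one` layer and then projected from 1024 down to `out_dim`:

texts = [["The", "cat", "sat"], ["Hello", "world"]]
model = getElmo(layer=2, dropout=0, out_dim=100, gpu=False)
hidden, mask = model(texts)
print(hidden.shape)  # (2, 3, 100): batch x longest sentence x out_dim
print(mask.shape)    # (2, 3)
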
Example No. 5
class ElmoEmbedding:
    def __init__(self, dim):
        if dim == 2048:
            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
        elif dim == 512:
            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
        self.dim = dim
        self.elmo = Elmo(options_file, weight_file, 2, dropout=0)
        if func.gpu_available():
            self.elmo = self.elmo.cuda()
        self.elmo.eval()
        self.load()


    def save(self):
        pass


    def load(self):
        self.cache = DiskDict(f'./generate/elmo.{self.dim}.cache')


    def convert(self, sentences):
        not_hit = set()
        for sent in sentences:
            key = self.make_key(sent)
            if key not in self.cache:
                not_hit.add(key)
        not_hit = list(not_hit)
        if not_hit:
            embeddings, masks = self.convert_impl([self.make_sentence(key) for key in not_hit])
            for key, embedding, mask in zip(not_hit, torch.unbind(embeddings), torch.unbind(masks)):
                embedding = embedding[:mask.sum()]
                self.cache[key] = embedding.tolist()
        embeddings = [func.tensor(self.cache[self.make_key(sent)]) for sent in sentences]
        mlen = max([e.shape[0] for e in embeddings])
        embeddings = [func.pad_zeros(e, mlen, 0) for e in embeddings]
        embeddings = torch.stack(embeddings)
        assert embeddings.requires_grad == False
        return embeddings


    def make_key(self, sent):
        return '$$'.join(sent)


    def make_sentence(self, key):
        return key.split('$$')


    def convert_impl(self, sentences):
        character_ids = func.tensor(batch_to_ids(sentences))
        m = self.elmo(character_ids)
        embeddings = m['elmo_representations']
        embeddings = torch.cat(embeddings, -1)
        mask = m['mask']
        return embeddings, mask
Example No. 6
class FeatureExtractor:
    def __init__(self, cfg):
        self.cfg = cfg

        options_file = self.cfg.elmo['options_file']
        weights_file = self.cfg.elmo['weights_file']
        self.encoder = Elmo(options_file, weights_file, 1, dropout=0)
        if self.cfg.use_gpu:
            self.encoder.cuda()
        self.encoder.eval()
        print(
            f"Elmo initialized with options:\n{options_file}\n{weights_file}.",
            end='\n\n')

    def process(self, dialog_paths):
        """ Write features to each corresponding dialog file."""
        for dialog_path in dialog_paths:
            print(f"Processing {dialog_path.stem}...")
            with dialog_path.open("r") as fin:
                dialog_annotations = json.load(fin)
                features = []
                for dialog_annotation in dialog_annotations:
                    features.extend(self.process_dialog(dialog_annotation))

                dialog_feature_path = dialog_path.with_suffix(".pt")
                torch.save({"features": features}, dialog_feature_path)

    def process_dialog(self, dialog):
        """Elmo representation is extracted for each candidate in a turn."""
        features = []
        for turn_idx, turn in enumerate(dialog["turns"]):
            tokens = [turn["tokens"]]
            token_ids = batch_to_ids(tokens)
            if self.cfg.use_gpu:
                token_ids = token_ids.cuda()
            embeddings = self.encoder(
                token_ids)["elmo_representations"][0].detach().cpu().data
            reps = []
            for _, candidates in turn["candidates"].items():
                for candidate in candidates:
                    rep = embeddings[0, candidate[0], :].detach()
                    reps.append(rep)
            features.append(reps)
        return features
Example No. 7
    def test_elmo_bilm_can_handle_higher_dimensional_input_with_cache(self):
        sentences = [["This", "is", "a", "sentence"],
                     ["Here", "'s", "one"],
                     ["Another", "one"]]
        vocab, tensor = self.get_vocab_and_both_elmo_indexed_ids(sentences)
        words_to_cache = list(vocab.get_token_to_index_vocabulary("tokens").keys())
        elmo_bilm = Elmo(self.options_file, self.weight_file, 1, vocab_to_cache=words_to_cache)
        elmo_bilm.eval()

        individual_dim = elmo_bilm(tensor["character_ids"], tensor["tokens"])
        elmo_bilm = Elmo(self.options_file, self.weight_file, 1, vocab_to_cache=words_to_cache)
        elmo_bilm.eval()

        expanded_word_ids = torch.stack([tensor["tokens"] for _ in range(4)], dim=1)
        expanded_char_ids = torch.stack([tensor["character_ids"] for _ in range(4)], dim=1)
        expanded_result = elmo_bilm(expanded_char_ids, expanded_word_ids)
        split_result = [x.squeeze(1) for x in torch.split(expanded_result["elmo_representations"][0], 1, dim=1)]
        for expanded in split_result:
            numpy.testing.assert_array_almost_equal(expanded.data.cpu().numpy(),
                                                    individual_dim["elmo_representations"][0].data.cpu().numpy())
Example No. 8
class FeatureExtractor():
    def __init__(self, elmo_weight=None, elmo_option=None):
        '''Runs on the GPU when CUDA is available, otherwise on the CPU.'''
        super().__init__()
        pre = 'datasets/elmo_pretrained/'
        weights_file = pre + 'elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'
        options_file = pre + 'option.json'
        self.encoder = Elmo(options_file, weights_file, 1, dropout=0)
        if torch.cuda.is_available(): self.encoder = self.encoder.cuda()
        self.encoder.eval()
        print(
            f"Elmo initialized with options:\n{options_file}\n{weights_file}.",
            end='\n\n')

    def __call__(self, tokens):
        """Elmo representation is extracted for each candidate in a turn."""
        token_ids = batch_to_ids(tokens)
        if torch.cuda.is_available(): token_ids = token_ids.cuda()
        embeddings = self.encoder(
            token_ids)["elmo_representations"][0].detach().cpu().data
        return embeddings
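
A short usage sketch, assuming the local option/weight files referenced above exist; with the small 2x1024_128_2048cnn_1xhighway model the representations are 256-dimensional:

extractor = FeatureExtractor()
embeddings = extractor([["Hello", "there"], ["Another", "turn", "here"]])
print(embeddings.shape)  # (2, 3, 256): batch x longest turn x ELMo dim
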
Example No. 9
class ElmoWrapper(nn.Module):
    def __init__(self, args):

        super(ElmoWrapper, self).__init__()

        options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
        self.elmo = Elmo(options_file, weight_file, 2,
                         dropout=0.0).to(args.device)  # 2 layers
        self.elmo.eval()

    def forward(self, tokenid):
        '''
        > tokenid (batch, seqlen, .) int
        < emb (batch, seqlen, d_lang)
        < mask (batch, seqlen) bool
        '''

        with torch.no_grad():
            emb = self.elmo(tokenid)['elmo_representations'][-1]
        mask = tokenid[:, :, 0] != 0
        return emb, mask
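
A minimal usage sketch of ElmoWrapper; `args` only needs a `device` attribute here, and the mask falls out of the character-id tensor because padded token positions are all-zero rows:

import argparse
import torch
from allennlp.modules.elmo import batch_to_ids

args = argparse.Namespace(device=torch.device("cpu"))
wrapper = ElmoWrapper(args)

tokenid = batch_to_ids([["Hello", "world"], ["Hi"]]).to(args.device)
emb, mask = wrapper(tokenid)
print(emb.shape)  # (2, 2, 1024)
print(mask)       # [[True, True], [True, False]]
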
Example No. 10
def calculate_post_elmo_embeddings(posts: Dict[str, Post],
                                   max_sentence_length: int,
                                   batch_size: int,
                                   scalar_mix_parameters: List[float],
                                   device: torch.device) \
        -> Dict[str, torch.Tensor]:
    """Calculate ELMo embeddings of all posts in the dataset.

    Calculating these embeddings once, before training the actual models,
    allows for extremely fast training later. One downside is that we can't
    propagate gradients through the embeddings, but fine-tuning them would
    probably lead to overfitting anyway, since our dataset is very small.
    Additionally, we can't learn the scalar_mix_parameters, but since
    training is so much faster, adjusting them by hand should be sufficient.

    Since we are going to load the entire dataset into GPU memory later anyway,
    we keep the embeddings in GPU memory here as well.

    Args:
        posts: A dictionary mapping post IDs to their respective posts. Load
            this with `src.dataset.load_posts()`.
        max_sentence_length: Number of tokens after which sentences will be
            truncated.
        batch_size: Batch size for calculating the ELMo embeddings.
        scalar_mix_parameters: Parameters for mixing the different ELMo layers.
            See the paper for details on this.
        device: Device to execute on.

    Returns:
        A dictionary mapping post IDs to their respective ELMo embedding in a
        PyTorch tensor. Each tensor will have shape
        `(num_elmo_dimensions, max_sentence_length)`.
    """

    print('Calculating post embeddings...')
    time_before = time()

    elmo = Elmo(ELMO_OPTIONS_FILE,
                ELMO_WEIGHTS_FILE,
                num_output_representations=1,
                dropout=0,
                requires_grad=False,
                do_layer_norm=False,
                scalar_mix_parameters=scalar_mix_parameters).to(device)
    elmo.eval()

    post_embeddings = {}
    batch_ids = []
    # Add a dummy sentence of max_sentence_length tokens to each batch to enforce
    # that each batch of embeddings has the same shape. `batch_to_ids()` and
    # `elmo()` take care of zero padding shorter sentences for us.
    batch_texts = [['' for _ in range(max_sentence_length)]]
    for i, post in enumerate(posts.values()):
        batch_ids.append(post.id)
        batch_texts.append(post.text[:max_sentence_length])

        if not i % batch_size or i == len(posts) - 1:
            batch_character_ids = batch_to_ids(batch_texts).to(device)
            batch_texts = [['' for _ in range(max_sentence_length)]]

            # - [0] to select the first output representation (there is only one
            #   because of `num_output_representations=1` at `elmo` creation).
            # - [1:] to ignore the dummy sentence added at the start.
            batch_embeddings = \
                elmo(batch_character_ids)['elmo_representations'][0][1:]
            batch_embeddings = batch_embeddings.split(split_size=1, dim=0)
            del batch_character_ids  # Free up memory sooner.

            for post_id, post_embedding in zip(batch_ids, batch_embeddings):
                post_embedding.squeeze_(dim=0)
                post_embedding.transpose_(0, 1)
                post_embeddings[post_id] = post_embedding
            batch_ids = []

    time_after = time()
    print('  Took {:.2f}s.'.format(time_after - time_before))

    return post_embeddings
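
The dummy-sentence trick from the comments above can be seen in isolation: prepending a sentence of `max_sentence_length` dummy tokens forces `batch_to_ids` (and therefore the ELMo output) to the same sequence length for every batch:

from allennlp.modules.elmo import batch_to_ids

max_sentence_length = 12
batch_texts = [['' for _ in range(max_sentence_length)]]  # dummy sentence
batch_texts.append("a much shorter post".split())
character_ids = batch_to_ids(batch_texts)
assert character_ids.shape[1] == max_sentence_length
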
Example No. 11
def concatenation_encode(data_path):
    file_name_list = [
        "negation_detection.txt", "negation_variant.txt",
        "clause_relatedness.txt", "argument_sensitivity.txt",
        "fixed_point_inversion.txt"
    ]
    accuracy_function = [
        normal_accuracy, negation_variant_accuracy, normal_accuracy,
        normal_accuracy, normal_accuracy
    ]
    file_path_list = [os.path.join(data_path, ele) for ele in file_name_list]
    options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
    weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
    elmo_model = Elmo(options_file, weight_file, 1)
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    elmo_model.eval()
    bert_model.eval()
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    for idx, path in enumerate(file_path_list):
        average_pooling_tensor = None
        max_pooling_tensor = None
        average_bert_tensor = None
        max_bert_tensor = None
        sentences = sentences_unfold(path, delimiter="\t")
        dataset = TextIndexDataset(sentences, tokenizer, True)
        data_loader = DataLoader(dataset,
                                 batch_size=48,
                                 num_workers=0,
                                 collate_fn=dataset.collate_fn_one2one)
        for ids, masks, elmo_ids in data_loader:
            masks = masks.float()
            if torch.cuda.is_available():
                ids = ids.cuda()
                masks = masks.cuda()
                elmo_ids = elmo_ids.cuda()
                bert_model = bert_model.cuda()
                elmo_model = elmo_model.cuda()
            with torch.no_grad():
                encoded_bert_layers, _ = bert_model(
                    ids, attention_mask=masks, output_all_encoded_layers=False)
                elmo_dict = elmo_model(elmo_ids)
                elmo_representations = elmo_dict["elmo_representations"][
                    0]  # type: torch.Tensor
                elmo_mask = elmo_dict["mask"]
                elmo_mask = elmo_mask.float()
                concatenated_layers = torch.cat(
                    (encoded_bert_layers, elmo_representations), dim=2)

                average_elmo = get_average_pooling(elmo_representations,
                                                   elmo_mask)
                max_elmo = get_max_pooling(elmo_representations)

                average_embeddings = get_average_pooling(
                    concatenated_layers, masks)
                max_pooling_embeddings = get_max_pooling(concatenated_layers)
                average_pooling_tensor = average_embeddings if average_pooling_tensor is None else torch.cat(
                    [average_pooling_tensor, average_embeddings], dim=0)
                max_pooling_tensor = max_pooling_embeddings if max_pooling_tensor is None else torch.cat(
                    [max_pooling_tensor, max_pooling_embeddings], dim=0)
                average_bert_tensor = average_elmo if average_bert_tensor is None else torch.cat(
                    [average_bert_tensor, average_elmo], dim=0)
                max_bert_tensor = max_elmo if max_bert_tensor is None else torch.cat(
                    [max_bert_tensor, max_elmo], dim=0)

        average_pooling_result = output_results(
            average_pooling_tensor.cpu().numpy(),
            calculate_accuracy=accuracy_function[idx])
        max_pooling_result = output_results(
            max_pooling_tensor.cpu().numpy(),
            calculate_accuracy=accuracy_function[idx])
        average_bert_result = output_results(
            average_bert_tensor.cpu().numpy(),
            calculate_accuracy=accuracy_function[idx])
        max_bert_result = output_results(
            max_bert_tensor.cpu().numpy(),
            calculate_accuracy=accuracy_function[idx])

        print("Result of average pooling bert on {0} dataset is: --------".
              format(file_name_list[idx]))
        print("\t& ".join(average_bert_result) + """\\""")
        print("Result of max pooling bert on {0} dataset is: --------".format(
            file_name_list[idx]))
        print("\t& ".join(max_bert_result) + """\\""")

        print(
            "Result of average pooling  concatenation on {0} dataset is: --------"
            .format(file_name_list[idx]))
        print("\t& ".join(average_pooling_result) + """\\""")
        print(
            "Result of max pooling concatenation on {0} dataset is: --------".
            format(file_name_list[idx]))
        print("\t& ".join(max_pooling_result) + """\\""")
Example No. 12
class PhraseEmbeddingSent(torch.nn.Module):

    def __init__(self, cfg, phrase_embed_dim=1024, bidirectional=True):
        super(PhraseEmbeddingSent, self).__init__()

        self.device = torch.device('cuda')
        self.bidirectional = bidirectional

        vocab_file = open(cfg.MODEL.VG.VOCAB_FILE)
        self.vocab = json.load(vocab_file)
        vocab_file.close()
        add_vocab = ['relate', 'butted']
        self.vocab.extend(add_vocab)
        self.vocab_to_id = {v: i + 1 for i, v in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab) + 1

        phr_vocab_file = open(cfg.MODEL.VG.VOCAB_PHR_FILE)
        self.phr_vocab = json.load(phr_vocab_file)
        self.phr_vocab_to_id = {v:i+1 for i, v in enumerate(self.phr_vocab)}
        self.phr_vocab_size = len(self.phr_vocab) + 1


        self.embed_dim = phrase_embed_dim

        if self.bidirectional:
            self.hidden_dim = phrase_embed_dim // 2
        else:
            self.hidden_dim = self.embed_dim

        if cfg.MODEL.VG.USING_ELMO:
            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
            self.elmo = Elmo(options_file, weight_file, 2, dropout=0, requires_grad=False)
            self.elmo.eval()
        else:
            self.enc_embedding = nn.Embedding(num_embeddings=self.vocab_size,
                                              embedding_dim=self.embed_dim,
                                              padding_idx=0, sparse=False)

        self.sent_rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim, num_layers=1,
                               batch_first=True, dropout=0, bidirectional=self.bidirectional, bias=True)

        if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
            with open(cfg.MODEL.VG.GLOVE_DICT_FILE, 'rb') as load_f:
                self.glove_embedding = pickle.load(load_f)  ## dict, which contain word embedding.


        if cfg.SOLVER.INIT_PARA:
            self.init_para()

    def init_para(self, ):

        # Initialize LSTM Weights and Biases
        for layer in self.sent_rnn._all_weights:
            for param_name in layer:
                if 'weight' in param_name:
                    weight = getattr(self.sent_rnn, param_name)
                    nn.init.xavier_normal_(weight.data)
                else:
                    bias = getattr(self.sent_rnn, param_name)
                    # bias.data.zero_()
                    nn.init.uniform_(bias.data, a=-0.01, b=0.01)
        nn.init.uniform_(self.enc_embedding.weight.data, a=-0.01, b=0.01)


    @staticmethod
    def filtering_phrase(phrases, all_phrase):
        phrase_valid = []
        for phr in phrases:
            if phr['phrase_id'] in all_phrase:
                phrase_valid.append(phr)
        return phrase_valid

    def forward(self, all_sentences, all_phrase_ids, all_sent_sgs):

        batch_phrase_ids = []
        batch_phrase_types = []
        batch_phrase_embed = []
        batch_phrase_len = []
        batch_phrase_dec_ids = []
        batch_phrase_mask = []
        batch_decoder_word_embed = []
        batch_glove_phrase_embed = []

        for idx, sent in enumerate(all_sentences):

            seq = sent['sentence'].lower()
            phrases = sent['phrases']
            phrase_ids = []
            phrase_types = []
            input_phr = []
            lengths = []

            valid_phrases = self.filtering_phrase(phrases, all_phrase_ids[idx])
            tokenized_seq = seq.split(' ')
            seq_enc_ids = [[self.vocab_to_id[w] for w in tokenized_seq]]

            """ Extract the word embedding and feed into sent_rnn"""
            if cfg.MODEL.VG.USING_ELMO:
                input_seq_idx = batch_to_ids([tokenized_seq]).to(self.device)
                seq_embed_b = self.elmo(input_seq_idx)['elmo_representations'][1] ## 1*L*1024
                seq_embed, hn = self.sent_rnn(seq_embed_b)
            else:

                seq_embed_b = self.enc_embedding(torch.as_tensor(seq_enc_ids).long().to(self.device)) # 1*L*1024
                seq_embed, hn = self.sent_rnn(seq_embed_b)

            # tokenized the phrase
            max_len = np.array([len(phr['phrase'].split(' ')) for phr in valid_phrases]).max()
            phrase_dec_ids = np.zeros((len(valid_phrases), max_len+1)) ## to predict end token
            phrase_mask = np.zeros((len(valid_phrases), max_len+1)) ## to predict the "end" token


            phrase_decoder_word_embeds = torch.zeros(len(valid_phrases), max_len, seq_embed.shape[-1]).to(self.device)  ##
            phrase_embeds = []

            phrase_glove_embedding = []
            for pid, phr in enumerate(valid_phrases):
                phrase_ids.append(phr['phrase_id'])
                phrase_types.append(phr['phrase_type'])
                tokenized_phr = phr['phrase'].lower().split(' ')
                phr_len = len(tokenized_phr)
                start_ind = phr['first_word_index']

                word_glove_embedding = []
                for wid, word in enumerate(tokenized_phr):
                    phrase_dec_ids[pid][wid] = self.phr_vocab_to_id[word]

                    if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
                        phr_glo_vec = self.glove_embedding.get(word)
                        if phr_glo_vec is not None:
                            word_glove_embedding.append(phr_glo_vec)

                if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
                    if len(word_glove_embedding) == 0:
                        word_glove_embedding = 0 * torch.as_tensor(self.glove_embedding.get('a')).float().unsqueeze(0)  ## 1*300
                    else:
                        word_glove_embedding = torch.as_tensor(np.array(word_glove_embedding)).float().mean(0, keepdim=True)
                    phrase_glove_embedding.append(word_glove_embedding)

                phrase_mask[pid][:phr_len+1] = 1
                phrase_decoder_word_embeds[pid, :phr_len, :] =  phrase_decoder_word_embeds[pid, :phr_len, :] + seq_embed_b[0][start_ind:start_ind+phr_len]

                if cfg.MODEL.VG.PHRASE_SELECT_TYPE == 'Sum':
                    phrase_embeds.append(seq_embed[[0], start_ind:start_ind+phr_len].sum(1))  # sum the embeddings over the phrase span
                elif cfg.MODEL.VG.PHRASE_SELECT_TYPE == 'Mean':
                    phrase_embeds.append(seq_embed[[0], start_ind:start_ind+phr_len].mean(1))

            phrase_embeds = torch.cat(phrase_embeds, dim=0)
            phrase_mask = torch.as_tensor(phrase_mask).float().to(self.device)
            if cfg.MODEL.VG.USING_DET_KNOWLEDGE:
                phrase_glove_embedding = torch.cat(phrase_glove_embedding, dim=0).to(self.device)  ## numP, 300


            batch_phrase_ids.append(phrase_ids)
            batch_phrase_types.append(phrase_types)
            batch_phrase_embed.append(phrase_embeds)
            batch_phrase_len.append(lengths)
            batch_phrase_dec_ids.append(phrase_dec_ids)
            batch_phrase_mask.append(phrase_mask)
            batch_decoder_word_embed.append(phrase_decoder_word_embeds)
            batch_glove_phrase_embed.append(phrase_glove_embedding)

        return batch_phrase_ids, batch_phrase_types, batch_phrase_embed, batch_phrase_len, \
               batch_phrase_dec_ids, batch_phrase_mask, batch_decoder_word_embed, batch_glove_phrase_embed
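
A tiny sketch of the PHRASE_SELECT_TYPE logic used above: a phrase embedding is just the sum or the mean of the contextual word embeddings covered by the phrase span:

import torch

seq_embed = torch.randn(1, 7, 1024)  # one sentence, 7 contextual word embeddings
start_ind, phr_len = 2, 3            # phrase covers tokens 2..4

phrase_sum = seq_embed[[0], start_ind:start_ind + phr_len].sum(1)    # (1, 1024)
phrase_mean = seq_embed[[0], start_ind:start_ind + phr_len].mean(1)  # (1, 1024)
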
Example No. 13
class PhraseEmbeddingSentElmo(torch.nn.Module):
    def __init__(self, cfg, phrase_embed_dim=1024, bidirectional=False):
        super(PhraseEmbeddingSentElmo, self).__init__()

        self.hidden_dim = phrase_embed_dim
        self.phrase_select_type = cfg.MODEL.VG.PHRASE_SELECT_TYPE
        self.bidirectional = bidirectional
        self.hidden_dim = phrase_embed_dim if not self.bidirectional else phrase_embed_dim // 2

        options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

        # Compute two different representations for each token.
        # Each representation is a linear weighted combination of the
        # 3 layers in ELMo (i.e., the char-CNN and the outputs of the two biLSTMs).
        self.elmo = Elmo(options_file,
                         weight_file,
                         2,
                         dropout=0,
                         requires_grad=False)
        self.elmo.eval()
        self.seq_rnn = nn.GRU(input_size=1024,
                              hidden_size=self.hidden_dim,
                              num_layers=1,
                              bias=True,
                              batch_first=True,
                              dropout=0,
                              bidirectional=bidirectional)
        # self.rel_rnn = nn.GRU(input_size=1024, hidden_size=self.hidden_dim, num_layers=1,
        #                   bias=True, batch_first=True, dropout=0,
        #                   bidirectional=bidirectional)
        # self.pos_enc = PositionalEncoder(d_model=1024)

        # if self.intra_language_relation_on:
        #     self.rel_rnn = nn.GRU(input_size=1024, hidden_size=self.hidden_dim, num_layers=1,
        #                           bias=True, batch_first=True, dropout=0, bidirectional=True)

    def forward(self, all_sentences, all_phrase_ids, all_sent_sgs, device_id):

        batch_phrase_ids = []
        batch_phrase_types = []
        batch_phrase_embed = []
        batch_rel_phrase_embed = []
        batch_relation_conn = []
        batch_word_embed = []
        batch_word_to_graph_conn = []

        for idx, sent in enumerate(all_sentences):

            seq = sent['sentence'].lower()
            phrases = sent['phrases']
            phrase_ids = []
            phrase_types = []
            input_phr = []
            lengths = []
            phrase_embeds_list = []

            valid_phrases = filter_phrase(phrases, all_phrase_ids[idx])
            tokenized_seq = seq.split(' ')
            # if flag_flip[idx] == 1:
            #     tokenized_seq = specific_word_replacement(tokenized_seq)

            input_seq_idx = batch_to_ids([tokenized_seq]).to(device_id)
            seq_embeds = self.elmo(input_seq_idx)['elmo_representations'][
                1]  ## 1*L*1024
            seq_embeds, hn = self.seq_rnn(seq_embeds)
            #TODO: encode position
            # seq_embeds = self.pos_enc(seq_embeds)
            word_to_graph_conn = np.zeros(
                (len(valid_phrases), seq_embeds.shape[1]))

            phr_select_ids = []
            for pid, phr in enumerate(valid_phrases):
                phrase_ids.append(phr['phrase_id'])
                phrase_types.append(phr['phrase_type'])
                tokenized_phr = phr['phrase'].lower().split(' ')
                # if flag_flip[idx] == 1:
                #     tokenized_phr = specific_word_replacement(tokenized_phr)

                phr_len = len(tokenized_phr)
                start_ind = phr['first_word_index']

                if self.phrase_select_type == 'Sum':
                    phrase_embeds_list.append(
                        torch.sum(
                            seq_embeds[:, start_ind:start_ind + phr_len, :],
                            1))
                elif self.phrase_select_type == 'Mean':
                    phrase_embeds_list.append(
                        torch.mean(
                            seq_embeds[:, start_ind:start_ind + phr_len, :],
                            1))
                else:
                    raise NotImplementedError('Phrase select type error')

                lengths.append(phr_len)
                input_phr.append(tokenized_phr)
                phr_select_ids.append(pid)
                word_to_graph_conn[pid, start_ind:start_ind + phr_len] = 1

            phrase_embeds = torch.cat(tuple(phrase_embeds_list), 0)

            batch_word_embed.append(seq_embeds[0])
            batch_phrase_ids.append(phrase_ids)
            batch_phrase_types.append(phrase_types)
            batch_phrase_embed.append(phrase_embeds)
            batch_word_to_graph_conn.append(word_to_graph_conn)
            """
            rel phrase embedding
            """
            # get sg
            sent_sg = all_sent_sgs[idx]
            relation_conn = []
            rel_lengths = []
            input_rel_phr = []

            for rel_id, rel in enumerate(sent_sg):
                sbj_id, obj_id, rel_phrase = rel
                if sbj_id not in phrase_ids or obj_id not in phrase_ids:
                    continue
                relation_conn.append([
                    phrase_ids.index(sbj_id),
                    phrase_ids.index(obj_id), rel_id
                ])

                tokenized_phr_rel = rel_phrase.lower().split(' ')
                if cfg.MODEL.RELATION.INCOR_ENTITIES_IN_RELATION:
                    tokenized_phr_rel = input_phr[phrase_ids.index(
                        sbj_id)] + tokenized_phr_rel + input_phr[
                            phrase_ids.index(obj_id)]
                # tokenized_phr_rel = input_phr[phrase_ids.index(sbj_id)] + tokenized_phr_rel + input_phr[phrase_ids.index(obj_id)]

                # if flag_flip[idx] == 1:
                #     tokenized_phr_rel = specific_word_replacement(tokenized_phr_rel)

                rel_phr_len = len(tokenized_phr_rel)
                rel_lengths.append(rel_phr_len)
                input_rel_phr.append(tokenized_phr_rel)

            if len(relation_conn) > 0:
                input_rel_phr_idx = batch_to_ids(input_rel_phr).to(device_id)
                rel_phrase_embeds = self.elmo(
                    input_rel_phr_idx)['elmo_representations'][1]
                # rel_phrase_embeds, _ = self.rel_rnn(rel_phrase_embeds)
                # rel_phrase_embeds, _ = self.seq_rnn(rel_phrase_embeds)
                rel_phrase_embeds = select_embed(
                    rel_phrase_embeds,
                    lengths=rel_lengths,
                    select_type=self.phrase_select_type)
                batch_rel_phrase_embed.append(rel_phrase_embeds)
            else:
                batch_rel_phrase_embed.append(None)

            batch_relation_conn.append(relation_conn)

        return batch_phrase_ids, batch_phrase_types, batch_word_embed, batch_phrase_embed, batch_rel_phrase_embed, batch_relation_conn, batch_word_to_graph_conn
Example No. 14
class PointerGenerator(nn.Module):
    def __init__(self,
                 vocab: List,
                 elmo_weights_file: str,
                 elmo_options_file: str,
                 elmo_embed_dim: int,
                 elmo_sent: bool = False,
                 alignment_model: str = "additive"):
        super(PointerGenerator, self).__init__()

        # Model Properties
        self.elmo_sent = elmo_sent
        self.alignment_model = alignment_model
        self.randomize_init_hidden = True
        self.vocab = sorted(vocab)
        self.vocab_2_ix = {k: v for k, v in zip(self.vocab, range(0, len(self.vocab)))}
        self.ix_2_vocab = {v: k for k, v in self.vocab_2_ix.items()}

        self.map_vocab_2_ix = lambda p_t: [[self.vocab_2_ix[w_t] for w_t in s_t] for s_t in p_t]
        self.map_ix_2_vocab = lambda p_i: [[self.ix_2_vocab[w_i] for w_i in s_i] for s_i in p_i]

        # Model Constants

        self.ELMO_EMBED_DIM = elmo_embed_dim  # This will change if the ELMO options/weights change

        self.VOCAB_SIZE = len(self.vocab)

        # Model Layers
        self.elmo = Elmo(elmo_options_file, elmo_weights_file, 1)
        self.elmo.eval()

        self.encoder = nn.LSTM(input_size=self.ELMO_EMBED_DIM,
                               hidden_size=self.ELMO_EMBED_DIM,
                               num_layers=1,
                               bidirectional=True)

        self.decoder = nn.LSTM(input_size=self.ELMO_EMBED_DIM,
                               hidden_size=2 * self.ELMO_EMBED_DIM,
                               num_layers=1,
                               bidirectional=False)

        self.Wh = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM,
                            out_features=2 * self.ELMO_EMBED_DIM,
                            bias=False)
        self.Ws = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM,
                            out_features=2 * self.ELMO_EMBED_DIM,
                            bias=True)
        self.v = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM,
                           out_features=1,
                           bias=False)

        self.sm_dim0 = nn.Softmax(dim=0)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        self.Vocab_Project_1 = nn.Linear(in_features=4 * self.ELMO_EMBED_DIM,
                                         out_features=8 * self.ELMO_EMBED_DIM,
                                         bias=True)

        self.Vocab_Project_2 = nn.Linear(in_features=8 * self.ELMO_EMBED_DIM,
                                         out_features=self.VOCAB_SIZE,
                                         bias=True)

        self.Wh_pgen = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM, out_features=1, bias=False)
        self.Ws_pgen = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM, out_features=1, bias=False)
        self.Wx_pgen = nn.Linear(in_features=self.ELMO_EMBED_DIM, out_features=1, bias=True)

    def _elmo_embed_doc(self, doc_tokens: List[List[str]]) -> torch.Tensor:
        if not self.elmo_sent:
            doc_tokens = [[token for sent_tokens in doc_tokens for token in sent_tokens]]

        doc_elmo_ids = batch_to_ids(doc_tokens)
        doc_elmo_embed = self.elmo(doc_elmo_ids)

        if self.elmo_sent:
            _elmo_doc_feats = []
            for sent_elmo_embed, sent_elmo_mask in zip(doc_elmo_embed['elmo_representations'][0],
                                                       doc_elmo_embed['mask']):
                _elmo_doc_feats.append(sent_elmo_embed[:sum(sent_elmo_mask)])
            elmo_doc_feats = torch.cat(_elmo_doc_feats, dim=0)
        else:
            elmo_doc_feats = doc_elmo_embed['elmo_representations'][0][0]
        return elmo_doc_feats

    def _embed_doc(self, doc_tokens: List[List[str]], **kwargs) -> torch.Tensor:
        # Embed the Doc with Elmo
        doc_embedded_elmo = self._elmo_embed_doc(doc_tokens)
        #
        # print("Pre Doc Shape -> {0}".format(doc_embedded_elmo.shape))

        prepend = kwargs.get('prepend_START', None)
        if prepend:
            start_token_elmo = self._elmo_embed_doc([['<START>']])
            doc_embedded_elmo = torch.cat((start_token_elmo, doc_embedded_elmo[:-1]), dim=0)

        # print("Post Doc Shape -> {0}".format(doc_embedded_elmo.shape))

        return doc_embedded_elmo

    def _init_bi_hidden(self, batch_size: int = 1, num_layers: int = 1):
        if self.randomize_init_hidden:
            init_hidden = torch.randn(num_layers * 2, batch_size,
                                      self.ELMO_EMBED_DIM)
        else:
            init_hidden = torch.zeros(num_layers * 2, batch_size,
                                      self.ELMO_EMBED_DIM)
        return init_hidden, init_hidden

    def _run_through_bilstm(self, input_tensor: torch.Tensor, bilstm: nn.LSTM):
        init_bi_hidden = self._init_bi_hidden(num_layers=bilstm.num_layers)
        output_tensor, _ = bilstm(input_tensor, init_bi_hidden)
        output_tensor = output_tensor.view(input_tensor.shape[0], 1, 2, bilstm.hidden_size)
        output_tensor = torch.cat((output_tensor[:, :, 0, :],
                                   output_tensor[:, :, 1, :]),
                                  dim=2).squeeze(dim=1)
        return output_tensor

    def _align(self, s, h, alignment_model="additive"):
        if alignment_model == "additive":
            # Attention alignment model from Bahdanau et al. (2015)
            e = self.v(self.tanh(self.Wh(h) + self.Ws(s))).squeeze()
        elif alignment_model == "dot_product":
            # Attention alignment model from Luong et al. (2015)
            e = torch.matmul(h, s.squeeze(dim=0))
        return e

    def _decoder_train(self, encoder_states, src_tokens, tgt_tokens):
        _init_probe = encoder_states[-1].reshape(1, 1, -1)
        curr_h, curr_c = (_init_probe, torch.randn_like(_init_probe))

        tgt_elmo = self._embed_doc(tgt_tokens, prepend_START=True)

        flat_src_tokens = [i for j in src_tokens for i in j]

        assert len(flat_src_tokens) == encoder_states.shape[0]

        new_words = sorted(list(set([w for w in flat_src_tokens if w not in self.vocab])))

        extended_vocab = self.vocab + new_words
        extended_vocab_2_ix = {**self.vocab_2_ix, **{w: ix for w, ix in zip(new_words, range(
            len(self.vocab), len(extended_vocab)))}}
        extended_ix_2_vocab = {v: k for k, v in extended_vocab_2_ix.items()}

        assert len(extended_vocab) == len(extended_vocab_2_ix) == len(extended_ix_2_vocab)

        # To calculate loss
        collect_xtnd_vocab_prjtns = []

        for curr_elmo in tgt_elmo:
            p_vocab = torch.zeros(size=(1, len(extended_vocab)))
            p_attn = torch.zeros(size=(1, len(extended_vocab)))

            curr_i = curr_elmo.reshape(1, 1, -1)
            curr_o, (curr_h, curr_c) = self.decoder(curr_i, (curr_h, curr_c))

            # Calculate Context Vector
            curr_attn = self._align(s=curr_h.squeeze(dim=1), h=encoder_states,
                                    alignment_model=self.alignment_model)
            curr_attn = self.sm_dim0(curr_attn)
            curr_ctxt = torch.matmul(curr_attn, encoder_states)

            # Concatenate Context & Decoder Hidden State
            state_ctxt_concat = torch.cat((curr_h.squeeze(), curr_ctxt))

            # Project to Vocabulary
            vocab_prjtn = self.Vocab_Project_2(self.Vocab_Project_1(state_ctxt_concat))
            p_vocab[:, :self.VOCAB_SIZE] = vocab_prjtn
            for src_word, src_attn in zip(flat_src_tokens, curr_attn):
                p_attn[:, extended_vocab_2_ix[src_word]] += src_attn

            p_gen = self.sigmoid(
                self.Wh_pgen(curr_ctxt) + self.Ws_pgen(curr_h.squeeze()) + self.Wx_pgen(curr_i.squeeze()))

            p_W = p_gen * p_vocab + (1 - p_gen) * p_attn
            collect_xtnd_vocab_prjtns.append(p_W)

        collect_xtnd_vocab_prjtns = torch.cat(collect_xtnd_vocab_prjtns, dim=0)

        return (collect_xtnd_vocab_prjtns, extended_vocab_2_ix, extended_ix_2_vocab)

    def _decoder_test(self, encoder_states, src_tokens, len_of_summary: int):
        _init_probe = encoder_states[-1].reshape(1, 1, -1)
        curr_h, curr_c = (_init_probe, torch.randn_like(_init_probe))

        flat_src_tokens = [i for j in src_tokens for i in j]

        assert len(flat_src_tokens) == encoder_states.shape[0]

        new_words = sorted(list(set([w for w in flat_src_tokens if w not in self.vocab])))

        extended_vocab = self.vocab + new_words
        _extend_2_ix = {w: ix for w, ix in zip(new_words, range(len(self.vocab), len(extended_vocab)))}
        extended_vocab_2_ix = {**self.vocab_2_ix, **_extend_2_ix}
        extended_ix_2_vocab = {v: k for k, v in extended_vocab_2_ix.items()}

        assert len(extended_vocab) == len(extended_vocab_2_ix) == len(extended_ix_2_vocab)

        # for curr_elmo in tgt_elmo:
        collected_summary_tokens = [['<START>']]
        curr_elmo = self._embed_doc(doc_tokens=collected_summary_tokens)

        collect_xtnd_vocab_prjtns = []

        for token_ix in range(len_of_summary):
            p_vocab = torch.zeros(size=(1, len(extended_vocab)))
            p_attn = torch.zeros(size=(1, len(extended_vocab)))

            curr_i = curr_elmo.reshape(1, 1, -1)
            curr_o, (curr_h, curr_c) = self.decoder(curr_i, (curr_h, curr_c))

            # Calculate Context Vector
            curr_attn = self._align(s=curr_h.squeeze(dim=1), h=encoder_states,
                                    alignment_model=self.alignment_model)
            curr_attn = self.sm_dim0(curr_attn)
            curr_ctxt = torch.matmul(curr_attn, encoder_states)

            # Concatenate Context & Decoder Hidden State
            state_ctxt_concat = torch.cat((curr_h.squeeze(), curr_ctxt))

            # Project to Vocabulary
            vocab_prjtn = self.Vocab_Project_2(self.Vocab_Project_1(state_ctxt_concat))
            p_vocab[:, :self.VOCAB_SIZE] = vocab_prjtn
            for src_word, src_attn in zip(flat_src_tokens, curr_attn):
                p_attn[:, extended_vocab_2_ix[src_word]] += src_attn

            p_gen = self.sigmoid(
                self.Wh_pgen(curr_ctxt) + self.Ws_pgen(curr_h.squeeze()) + self.Wx_pgen(curr_i.squeeze()))

            p_W = p_gen * p_vocab + (1 - p_gen) * p_attn
            collect_xtnd_vocab_prjtns.append(p_W)

            curr_pred_token = extended_ix_2_vocab[p_W.argmax(dim=1).item()]
            collected_summary_tokens[-1].append(curr_pred_token)

            # Just get the Elmo Embedding of the Latest Word of the Latest Sentence
            curr_elmo = self._embed_doc([collected_summary_tokens[-1]])[-1]

            if curr_pred_token == '.':
                # Start a New Line
                collected_summary_tokens.append([])

        collect_xtnd_vocab_prjtns = torch.cat(collect_xtnd_vocab_prjtns, dim=0)

        return (collect_xtnd_vocab_prjtns, extended_vocab_2_ix, extended_ix_2_vocab)

    def forward(self, orig_text_tokens: List[List[str]], **kwargs) -> tuple:
        # Embed the Orig with Elmo
        orig_embedded_elmo = self._embed_doc(orig_text_tokens)

        # Encode with BiLSTM
        orig_embedded_elmo.unsqueeze_(dim=1)
        encoder_states = self._run_through_bilstm(orig_embedded_elmo, self.encoder)

        # summ_text implies training
        summ_text_tokens = kwargs.get('summ_text_tokens', None)
        summary_length = kwargs.get('summ_text_length', None)

        if summ_text_tokens:
            # -> Training Loop
            print("Training")
            prjtns, v2i, i2v = self._decoder_train(encoder_states=encoder_states,
                                                   src_tokens=orig_text_tokens,
                                                   tgt_tokens=summ_text_tokens)
        else:
            # -> Inference Loop
            print("Testing")
            target_length = 30
            if summary_length:
                target_length = summary_length
            prjtns, v2i, i2v = self._decoder_test(encoder_states=encoder_states,
                                                  src_tokens=orig_text_tokens,
                                                  len_of_summary=target_length)
        return (prjtns, v2i, i2v)
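
The two alignment models referenced in `_align` can be sketched standalone, with random tensors standing in for the real decoder state and encoder states (additive attention from Bahdanau et al. (2015), dot-product attention from Luong et al. (2015)):

import torch
import torch.nn as nn

d = 8                  # stands in for 2 * ELMO_EMBED_DIM
h = torch.randn(5, d)  # encoder states, one per source token
s = torch.randn(1, d)  # current decoder state

Wh, Ws = nn.Linear(d, d, bias=False), nn.Linear(d, d, bias=True)
v = nn.Linear(d, 1, bias=False)

e_additive = v(torch.tanh(Wh(h) + Ws(s))).squeeze()  # (5,) additive scores
e_dot = torch.matmul(h, s.squeeze(dim=0))            # (5,) dot-product scores

attn = torch.softmax(e_additive, dim=0)  # attention distribution over source tokens
context = torch.matmul(attn, h)          # context vector, shape (d,)
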
Example No. 15
class ContextualControllerELMo(ControllerBase):
    def __init__(
            self,
            hidden_size,
            dropout,
            pretrained_embeddings_dir,
            dataset_name,
            fc_hidden_size=150,
            freeze_pretrained=True,
            learning_rate=0.001,
            layer_learning_rate: Optional[Dict[str, float]] = None,
            max_segment_size=None,  # if None, process sentences independently
            max_span_size=10,
            model_name=None):
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.freeze_pretrained = freeze_pretrained
        self.fc_hidden_size = fc_hidden_size
        self.max_span_size = max_span_size
        self.max_segment_size = max_segment_size
        self.learning_rate = learning_rate
        self.layer_learning_rate = layer_learning_rate if layer_learning_rate is not None else {}

        self.pretrained_embeddings_dir = pretrained_embeddings_dir
        self.embedder = Elmo(
            options_file=os.path.join(pretrained_embeddings_dir,
                                      "options.json"),
            weight_file=os.path.join(pretrained_embeddings_dir,
                                     "slovenian-elmo-weights.hdf5"),
            dropout=(0.0 if freeze_pretrained else dropout),
            num_output_representations=1,
            requires_grad=(not freeze_pretrained)).to(DEVICE)
        embedding_size = self.embedder.get_output_dim()

        self.context_encoder = nn.LSTM(input_size=embedding_size,
                                       hidden_size=hidden_size,
                                       batch_first=True,
                                       bidirectional=True).to(DEVICE)
        self.scorer = NeuralCoreferencePairScorer(num_features=(2 *
                                                                hidden_size),
                                                  hidden_size=fc_hidden_size,
                                                  dropout=dropout).to(DEVICE)
        params_to_update = [{
            "params":
            self.scorer.parameters(),
            "lr":
            self.layer_learning_rate.get("lr_scorer", self.learning_rate)
        }, {
            "params":
            self.context_encoder.parameters(),
            "lr":
            self.layer_learning_rate.get("lr_context_encoder",
                                         self.learning_rate)
        }]
        if not freeze_pretrained:
            params_to_update.append({
                "params":
                self.embedder.parameters(),
                "lr":
                self.layer_learning_rate.get("lr_embedder", self.learning_rate)
            })

        self.optimizer = optim.Adam(params_to_update, lr=self.learning_rate)

        super().__init__(learning_rate=learning_rate,
                         dataset_name=dataset_name,
                         model_name=model_name)
        logging.info(
            f"Initialized contextual ELMo-based model with name {self.model_name}."
        )

    @property
    def model_base_dir(self):
        return "contextual_model_elmo"

    def train_mode(self):
        if not self.freeze_pretrained:
            self.embedder.train()
        self.context_encoder.train()
        self.scorer.train()

    def eval_mode(self):
        self.embedder.eval()
        self.context_encoder.eval()
        self.scorer.eval()

    def load_checkpoint(self):
        self.loaded_from_file = True
        self.context_encoder.load_state_dict(
            torch.load(os.path.join(self.path_model_dir, "context_encoder.th"),
                       map_location=DEVICE))
        self.scorer.load_state_dict(
            torch.load(os.path.join(self.path_model_dir, "scorer.th"),
                       map_location=DEVICE))

        path_to_embeddings = os.path.join(self.path_model_dir, "embeddings.th")
        if os.path.isfile(path_to_embeddings):
            logging.info(
                f"Loading fine-tuned ELMo weights from '{path_to_embeddings}'")
            self.embedder.load_state_dict(
                torch.load(path_to_embeddings, map_location=DEVICE))

    @staticmethod
    def from_pretrained(model_dir):
        controller_config_path = os.path.join(model_dir,
                                              "controller_config.json")
        with open(controller_config_path, "r", encoding="utf-8") as f_config:
            pre_config = json.load(f_config)

        instance = ContextualControllerELMo(**pre_config)
        instance.load_checkpoint()

        return instance

    def save_pretrained(self, model_dir):
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        # Write controller config (used for instantiation)
        controller_config_path = os.path.join(model_dir,
                                              "controller_config.json")
        with open(controller_config_path, "w", encoding="utf-8") as f_config:
            json.dump(
                {
                    "hidden_size": self.hidden_size,
                    "dropout": self.dropout,
                    "pretrained_embeddings_dir":
                    self.pretrained_embeddings_dir,
                    "dataset_name": self.dataset_name,
                    "fc_hidden_size": self.fc_hidden_size,
                    "freeze_pretrained": self.freeze_pretrained,
                    "learning_rate": self.learning_rate,
                    "layer_learning_rate": self.layer_learning_rate,
                    "max_segment_size": self.max_segment_size,
                    "max_span_size": self.max_span_size,
                    "model_name": self.model_name
                },
                fp=f_config,
                indent=4)

        torch.save(self.context_encoder.state_dict(),
                   os.path.join(self.path_model_dir, "context_encoder.th"))
        torch.save(self.scorer.state_dict(),
                   os.path.join(self.path_model_dir, "scorer.th"))

        # Save fine-tuned ELMo embeddings only if they're not frozen
        if not self.freeze_pretrained:
            torch.save(self.embedder.state_dict(),
                       os.path.join(self.path_model_dir, "embeddings.th"))

    def save_checkpoint(self):
        logging.warning(
            "save_checkpoint() is deprecated. Use save_pretrained() instead")
        self.save_pretrained(self.path_model_dir)

    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled. """
        ret = {}

        # By default, each sentence is its own segment, meaning sentences are processed independently
        if self.max_segment_size is None:

            def get_position(t):
                return t.sentence_index, t.position_in_sentence

            _encoded_segments = batch_to_ids(curr_doc.raw_sentences())
        # Optionally, one can specify max_segment_size, in which case segments of tokens are processed independently
        else:

            def get_position(t):
                doc_position = t.position_in_document
                return doc_position // self.max_segment_size, doc_position % self.max_segment_size

            flattened_doc = list(chain(*curr_doc.raw_sentences()))
            num_segments = (len(flattened_doc) + self.max_segment_size -
                            1) // self.max_segment_size
            _encoded_segments = \
                batch_to_ids([flattened_doc[idx_seg * self.max_segment_size: (idx_seg + 1) * self.max_segment_size]
                              for idx_seg in range(num_segments)])

        encoded_segments = []
        # Convention: Add a PAD word ([0] * max_chars vector) at the end of each segment, for padding mentions
        for curr_sent in _encoded_segments:
            encoded_segments.append(
                torch.cat((curr_sent,
                           torch.zeros(
                               (1, ELMoCharacterMapper.max_word_length),
                               dtype=torch.long))))
        encoded_segments = torch.stack(encoded_segments)

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        all_candidate_data = []
        for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), 1):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

            # Note: no data for the dummy antecedent (len(candidate_data) is one less than len(candidates))
            candidates, candidate_data = [None], []
            candidate_attention = []
            correct_antecedents = []

            curr_head_data = [[], []]
            num_head_words = 0
            for curr_token in head_mention.tokens:
                idx_segment, idx_inside_segment = get_position(curr_token)
                curr_head_data[0].append(idx_segment)
                curr_head_data[1].append(idx_inside_segment)
                num_head_words += 1

            if num_head_words > self.max_span_size:
                curr_head_data[0] = curr_head_data[0][:self.max_span_size]
                curr_head_data[1] = curr_head_data[1][:self.max_span_size]
            else:
                curr_head_data[0] += [curr_head_data[0][-1]] * (self.max_span_size - num_head_words)
                curr_head_data[1] += [-1] * (self.max_span_size - num_head_words)

            head_attention = torch.ones((1, self.max_span_size),
                                        dtype=torch.bool)
            head_attention[0, num_head_words:] = False

            for idx_candidate, (cand_id, cand_mention) in enumerate(
                    curr_doc.mentions.items(), start=1):
                if idx_candidate >= idx_head:
                    break

                candidates.append(cand_id)

                # Maps tokens to positions inside segments (idx_seg, idx_inside_seg) for efficient indexing later
                curr_candidate_data = [[], []]
                num_candidate_words = 0
                for curr_token in cand_mention.tokens:
                    idx_segment, idx_inside_segment = get_position(curr_token)
                    curr_candidate_data[0].append(idx_segment)
                    curr_candidate_data[1].append(idx_inside_segment)
                    num_candidate_words += 1

                if num_candidate_words > self.max_span_size:
                    curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size]
                    curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size]
                else:
                    # padding positions point at the PAD word appended to the segment of the last real token
                    curr_candidate_data[0] += [curr_candidate_data[0][-1]] * (self.max_span_size - num_candidate_words)
                    curr_candidate_data[1] += [-1] * (self.max_span_size - num_candidate_words)

                candidate_data.append(curr_candidate_data)
                curr_attention = torch.ones((1, self.max_span_size),
                                            dtype=torch.bool)
                curr_attention[0, num_candidate_words:] = False
                candidate_attention.append(curr_attention)

                is_coreferent = cand_id in gt_antecedent_ids
                if is_coreferent:
                    correct_antecedents.append(idx_candidate)

            if len(correct_antecedents) == 0:
                correct_antecedents.append(0)

            candidate_attention = torch.cat(
                candidate_attention) if len(candidate_attention) > 0 else []
            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": candidate_attention,
                "correct_antecedents": correct_antecedents
            })

        ret["preprocessed_segments"] = encoded_segments
        ret["steps"] = all_candidate_data

        return ret

    def _train_doc(self, curr_doc, eval_mode=False):
        """ Trains/evaluates (if `eval_mode` is True) model on specific document.
            Returns predictions, loss and number of examples evaluated. """

        if len(curr_doc.mentions) == 0:
            return {}, (0.0, 0)

        if not hasattr(curr_doc, "_cache_elmo"):
            curr_doc._cache_elmo = self._prepare_doc(curr_doc)
        cache = curr_doc._cache_elmo  # type: Dict

        encoded_segments = cache["preprocessed_segments"]
        if self.freeze_pretrained:
            with torch.no_grad():
                res = self.embedder(encoded_segments.to(DEVICE))
        else:
            res = self.embedder(encoded_segments.to(DEVICE))

        # Note: max_segment_size is either specified at instantiation or (length of the longest sentence + 1)
        embedded_segments = res["elmo_representations"][0]  # [num_segments, max_segment_size, embedding_size]
        (lstm_segments, _) = self.context_encoder(embedded_segments)  # [num_segments, max_segment_size, 2 * hidden_size]

        doc_loss, n_examples = 0.0, len(cache["steps"])
        preds = {}

        for curr_step in cache["steps"]:
            head_id = curr_step["head_id"]
            head_data = curr_step["head_data"]

            candidates = curr_step["candidates"]
            candidate_data = curr_step["candidate_data"]
            correct_antecedents = curr_step["correct_antecedents"]

            # Note: num_candidates includes dummy antecedent + actual candidates
            num_candidates = len(candidates)
            if num_candidates == 1:
                curr_pred = 0
            else:
                idx_segment = candidate_data[:, 0, :]
                idx_in_segment = candidate_data[:, 1, :]

                # [num_candidates, max_span_size, embedding_size]
                candidate_data = lstm_segments[idx_segment, idx_in_segment]
                # [1, head_size, embedding_size]
                head_data = lstm_segments[head_data[:, 0, :], head_data[:, 1, :]]
                head_data = head_data.repeat((num_candidates - 1, 1, 1))

                candidate_scores = self.scorer(
                    candidate_data, head_data,
                    curr_step["candidate_attention"],
                    curr_step["head_attention"].repeat(
                        (num_candidates - 1, 1)))

                # [1, num_candidates]
                candidate_scores = torch.cat(
                    (torch.tensor([0.0], device=DEVICE),
                     candidate_scores.flatten())).unsqueeze(0)

                curr_pred = torch.argmax(candidate_scores)
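                # Repeat the score row once per gold antecedent so each correct antecedent contributes
                # its own term to the loss (assuming self.loss is a cross-entropy-style criterion)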
                doc_loss += self.loss(
                    candidate_scores.repeat((len(correct_antecedents), 1)),
                    torch.tensor(correct_antecedents, device=DEVICE))

            # Record the prediction as an { antecedent: [mention(s)] } entry
            existing_refs = preds.get(candidates[int(curr_pred)], [])
            existing_refs.append(head_id)
            preds[candidates[int(curr_pred)]] = existing_refs

        if not eval_mode and torch.is_tensor(doc_loss):
            # doc_loss stays a plain float (0.0) when every step only had the dummy candidate,
            # in which case there is nothing to backpropagate
            doc_loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        return preds, (float(doc_loss), n_examples)
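
For completeness, a minimal loading sketch that mirrors save_pretrained() above; load_pretrained and its controller_cls argument are hypothetical (pass in the controller class defined in this example), and it assumes the constructor accepts exactly the hyperparameters written to controller_config.json.

import json
import os

import torch


def load_pretrained(model_dir, controller_cls):
    # Rebuild the controller from the saved hyperparameters
    with open(os.path.join(model_dir, "controller_config.json"), encoding="utf-8") as f_config:
        config = json.load(f_config)
    model = controller_cls(**config)

    # Restore the trained submodules that were saved as state dicts
    model.context_encoder.load_state_dict(
        torch.load(os.path.join(model_dir, "context_encoder.th"), map_location="cpu"))
    model.scorer.load_state_dict(
        torch.load(os.path.join(model_dir, "scorer.th"), map_location="cpu"))

    # Fine-tuned ELMo weights exist only when freeze_pretrained was False during training
    embeddings_path = os.path.join(model_dir, "embeddings.th")
    if os.path.exists(embeddings_path):
        model.embedder.load_state_dict(torch.load(embeddings_path, map_location="cpu"))
    return model
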
Exemplo n.º 16
0
    X, y, EMO_LIST, NUM_EMO = cbet_data_other('median',
                                              remove_stop_words=False)
else:
    raise Exception('Dataset not recognized :(')

if opt.elmo == 'origin':
    options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
    weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
elif opt.elmo == 'origin55b':
    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5'
else:
    raise Exception('elmo model not recognized')

elmo = Elmo(options_file, weight_file, 2, dropout=0).cuda()
elmo.eval()

EMOS = EMO_LIST
EMOS_DIC = dict(zip(EMOS, range(len(EMOS))))

tokenizer = GloveTokenizer()

# deepmoji
print('Tokenizing using dictionary from {}'.format(VOCAB_PATH))
with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)
st = SentenceTokenizer(vocabulary, PAD_LEN)

print('Loading model from {}.'.format(PRETRAINED_PATH))
emoji_model = torchmoji_feature_encoding(PRETRAINED_PATH)
emoji_model.eval()
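
A short usage sketch for the two encoders configured above: the input texts are made up, batch_to_ids is assumed to be imported alongside Elmo, and the torchMoji calls follow the library's encode-texts example.

example_texts = ["I am so happy today", "this is terrible news"]  # hypothetical inputs

# ELMo: character ids -> two contextual representations (num_output_representations=2 above)
character_ids = batch_to_ids([t.split() for t in example_texts]).cuda()
with torch.no_grad():
    elmo_out = elmo(character_ids)
elmo_layers = elmo_out["elmo_representations"]  # list of 2 tensors, each [batch, seq_len, 1024]

# deepmoji: word ids -> 2304-dimensional sentence features
tokenized, _, _ = st.tokenize_sentences(example_texts)
with torch.no_grad():
    emoji_features = emoji_model(tokenized)  # [batch, 2304]
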
Exemplo n.º 17
0
        setting["parsed_data_path"]["unlabeled"]
    ]
    print("base_dirs are", base_dirs)

    corpus = ParsedCorpus(base_dirs)

    sentences_generator = corpus.get_single("sentences")
    corefs_generator = corpus.get_single("corefs")

    # if you are looking for an example, see https://allennlp.org/elmo
    # options_file = "/path/to/options.json"
    # weight_file = "path/to/weights.hdf5"
    options_file = args.options_file
    weight_file = args.weight_file
    encoder = Elmo(options_file, weight_file, 1, dropout=0)
    encoder.eval()
    encoder.cuda()

    pbar = tqdm.tqdm(range(len(corpus)))
    for _ in pbar:
        sentences, file_name = next(sentences_generator)
        corefs, _ = next(corefs_generator)
        save_name = file_name + ".pt"
        if os.path.exists(save_name):
            pbar.set_description_str(f"{save_name} exists, skipping"[-30:])
            continue
        # preprocess all sentences in a document
        doc = []
        for sentence in sentences:
            sentence = [token["word"] for token in sentence["tokens"]]
            doc.append(sentence)
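        # Hypothetical continuation (not part of the original snippet): encode the tokenized document
        # with the ELMo encoder configured above (assuming batch_to_ids is imported alongside Elmo)
        # and cache the representations under save_name.
        with torch.no_grad():
            character_ids = batch_to_ids(doc).cuda()
            doc_elmo = encoder(character_ids)["elmo_representations"][0].cpu()
        torch.save(doc_elmo, save_name)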
Exemplo n.º 18
0
class PointerGenerator(LightningModule):
    def __init__(self,
                 vocab: List,
                 elmo_weights_file: str,
                 elmo_options_file: str,
                 elmo_embed_dim: int,
                 elmo_sent: bool = False,
                 alignment_model: str = "additive"):
        super().__init__()
        self.save_hyperparameters()

        # Model Properties
        self.elmo_sent = elmo_sent
        self.alignment_model = alignment_model
        self.randomize_init_hidden = True
        self.vocab = sorted(vocab)
        self.vocab_2_ix = {k: v for k, v in zip(self.vocab, range(0, len(self.vocab)))}
        self.ix_2_vocab = {v: k for k, v in self.vocab_2_ix.items()}

        self.map_vocab_2_ix = lambda p_t: [[self.vocab_2_ix[w_t] for w_t in s_t] for s_t in p_t]
        self.map_ix_2_vocab = lambda p_i: [[self.ix_2_vocab[w_i] for w_i in s_i] for s_i in p_i]

        # Model Constants

        self.ELMO_EMBED_DIM = elmo_embed_dim  # This will change if the ELMO options/weights change

        self.VOCAB_SIZE = len(self.vocab)

        # Model Layers
        self.elmo = Elmo(elmo_options_file, elmo_weights_file, 1)
        self.elmo.eval()

        self.encoder = nn.LSTM(input_size=self.ELMO_EMBED_DIM,
                               hidden_size=self.ELMO_EMBED_DIM,
                               num_layers=1,
                               bidirectional=True)

        self.decoder = nn.LSTM(input_size=self.ELMO_EMBED_DIM,
                               hidden_size=2 * self.ELMO_EMBED_DIM,
                               num_layers=1,
                               bidirectional=False)

        self.Wh = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM,
                            out_features=2 * self.ELMO_EMBED_DIM,
                            bias=False)
        self.Ws = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM,
                            out_features=2 * self.ELMO_EMBED_DIM,
                            bias=True)
        self.v = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM,
                           out_features=1,
                           bias=False)

        self.sm_dim0 = nn.Softmax(dim=0)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

        self.Vocab_Project_1 = nn.Linear(in_features=4 * self.ELMO_EMBED_DIM,
                                         out_features=8 * self.ELMO_EMBED_DIM,
                                         bias=True)

        self.Vocab_Project_2 = nn.Linear(in_features=8 * self.ELMO_EMBED_DIM,
                                         out_features=self.VOCAB_SIZE,
                                         bias=True)

        self.Wh_pgen = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM, out_features=1, bias=False)
        self.Ws_pgen = nn.Linear(in_features=2 * self.ELMO_EMBED_DIM, out_features=1, bias=False)
        self.Wx_pgen = nn.Linear(in_features=self.ELMO_EMBED_DIM, out_features=1, bias=True)

    def _elmo_embed_doc(self, doc_tokens: List[List[str]]) -> torch.Tensor:
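        # If elmo_sent is False, the document is flattened into a single token sequence so ELMo
        # contextualizes across sentence boundaries; otherwise each sentence is embedded separately.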
        if not self.elmo_sent:
            doc_tokens = [[token for sent_tokens in doc_tokens for token in sent_tokens]]

        doc_elmo_ids = batch_to_ids(doc_tokens)
        doc_elmo_embed = self.elmo(doc_elmo_ids)

        if self.elmo_sent:
            _elmo_doc_feats = []
            for sent_elmo_embed, sent_elmo_mask in zip(doc_elmo_embed['elmo_representations'][0],
                                                       doc_elmo_embed['mask']):
                _elmo_doc_feats.append(sent_elmo_embed[:sum(sent_elmo_mask)])
            elmo_doc_feats = torch.cat(_elmo_doc_feats, dim=0)
        else:
            elmo_doc_feats = doc_elmo_embed['elmo_representations'][0][0]
        return elmo_doc_feats

    def _embed_doc(self, doc_tokens: List[List[str]], **kwargs) -> torch.Tensor:
        # Embed the Doc with Elmo
        doc_embedded_elmo = self._elmo_embed_doc(doc_tokens)
        #
        # print("Pre Doc Shape -> {0}".format(doc_embedded_elmo.shape))

        prepend = kwargs.get('prepend_START', None)
        if prepend:
            start_token_elmo = self._elmo_embed_doc([['<START>']])
            doc_embedded_elmo = torch.cat((start_token_elmo, doc_embedded_elmo[:-1]), dim=0)

        # print("Post Doc Shape -> {0}".format(doc_embedded_elmo.shape))

        return doc_embedded_elmo

    def _init_bi_hidden(self, batch_size: int = 1, num_layers: int = 1):
        if self.randomize_init_hidden:
            init_hidden = torch.randn(num_layers * 2, batch_size,
                                      self.ELMO_EMBED_DIM)
        else:
            init_hidden = torch.zeros(num_layers * 2, batch_size,
                                      self.ELMO_EMBED_DIM)
        return init_hidden, init_hidden

    def _run_through_bilstm(self, input_tensor: torch.Tensor, bilstm: nn.LSTM):
        init_bi_hidden = self._init_bi_hidden(num_layers=bilstm.num_layers)
        output_tensor, _ = bilstm(input_tensor, init_bi_hidden)
        output_tensor = output_tensor.view(input_tensor.shape[0], 1, 2, bilstm.hidden_size)
        output_tensor = torch.cat((output_tensor[:, :, 0, :],
                                   output_tensor[:, :, 1, :]),
                                  dim=2).squeeze(dim=1)
        return output_tensor

    def _align(self, s, h, alignment_model="additive"):
        if alignment_model == "additive":
            # Attention alignment model from Bahdanau et al. (2015)
            e = self.v(self.tanh(self.Wh(h) + self.Ws(s))).squeeze()
        elif alignment_model == "dot_product":
            # Attention alignment model from Luong et al. (2015)
            e = torch.matmul(h, s.squeeze(dim=0))
        else:
            raise ValueError(f"Unrecognized alignment_model: {alignment_model}")
        return e

    def _extend_vocab(self, possible_new_tokens):
        new_words = sorted(list(set([w for w in possible_new_tokens if w not in self.vocab])))
        extended_vocab = self.vocab + new_words
        extended_vocab_2_ix = {**self.vocab_2_ix, **{w: ix for w, ix in zip(new_words, range(
            len(self.vocab), len(extended_vocab)))}}
        extended_ix_2_vocab = {v: k for k, v in extended_vocab_2_ix.items()}
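        # e.g. (hypothetical) vocab = ["a", "b"] and new tokens ["zebra"] give extended_vocab = ["a", "b", "zebra"],
        # so copied out-of-vocabulary source words get their own slots in the output distribution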

        assert len(extended_vocab) == len(extended_vocab_2_ix) == len(extended_ix_2_vocab), "Vocab Length Mismatch"

        return extended_vocab, extended_vocab_2_ix, extended_ix_2_vocab

    # Lightning Methods
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self) -> DataLoader:
        train_dataset = loader.CNNLoader(path_to_csv='dataset/gttp_train.csv')
        train_loader = DataLoader(dataset=train_dataset,
                                  shuffle=True,
                                  batch_size=constant.BATCH_SIZE,
                                  num_workers=4)
        return train_loader

    def training_step(self, batch, batch_nb):
        batch_loss = 0
        batch_orig, batch_summ = batch
        for orig_text, summ_text in zip(batch_orig, batch_summ):
            summ_tokens = helper.tokenize_en(summ_text, lowercase=True)
            summ_tokens_flat = [i for j in summ_tokens for i in j]

            prjtns, v2i, _ = self(orig_text=orig_text, summ_text=summ_text)
            gold_ixs = torch.LongTensor([v2i.get(w, constant.UNK_TOK_IX) for w in summ_tokens_flat])

            batch_loss += loss_fn(input=prjtns, target=gold_ixs)

        return {'loss': batch_loss}

    def training_epoch_end(self, training_step_outputs):
        mean_train_loss = torch.stack([x['loss'] for x in training_step_outputs]).mean()
        return {
            'log': {'loss/train': mean_train_loss, 'step': self.current_epoch},
            'progress_bar': {'loss/train': mean_train_loss}
        }

    def val_dataloader(self) -> DataLoader:
        validation_dataset = loader.CNNLoader(path_to_csv='dataset/gttp_valid.csv')
        validation_loader = DataLoader(dataset=validation_dataset,
                                       shuffle=False,
                                       batch_size=constant.BATCH_SIZE,
                                       num_workers=4)
        return validation_loader

    def validation_step(self, batch, batch_nb):
        batch_loss = 0
        batch_orig, batch_summ = batch
        for orig_text, summ_text in zip(batch_orig, batch_summ):
            summ_tokens = helper.tokenize_en(summ_text, lowercase=True)
            summ_tokens_flat = [i for j in summ_tokens for i in j]

            prjtns, v2i, _ = self(orig_text=orig_text, summ_text=summ_text)
            gold_ixs = torch.LongTensor([v2i.get(w, constant.UNK_TOK_IX) for w in summ_tokens_flat])
            batch_loss += loss_fn(input=prjtns, target=gold_ixs)

        batch_loss /= len(batch_orig)

        return {'loss': batch_loss}

    def validation_epoch_end(self, validation_step_outputs):
        mean_validation_loss = torch.stack([x['loss'] for x in validation_step_outputs]).mean()
        return {
            'log': {'loss/validation': mean_validation_loss, 'step': self.current_epoch},
            'progress_bar': {'loss/validation': mean_validation_loss}
        }

    def forward(self, orig_text: str, **kwargs) -> tuple:
        orig_tokens = helper.tokenize_en(orig_text, lowercase=True)

        # Extend the vocabulary to include new words
        orig_tokens_flat = [i for j in orig_tokens for i in j]
        ex_vocab, ex_vocab_2_ix, ex_ix_2_vocab = self._extend_vocab(possible_new_tokens=orig_tokens_flat)

        # Embed the Orig with Elmo
        orig_elmo = self._embed_doc(orig_tokens)
        # Encode with BiLSTM
        orig_elmo.unsqueeze_(dim=1)
        encoder_states = self._run_through_bilstm(orig_elmo, self.encoder)
        assert len(orig_tokens_flat) == encoder_states.shape[0]

        # summ_text implies Training
        summ_text = kwargs.get('summ_text', None)

        if summ_text:
            # -> Training Loop
            summ_tokens = helper.tokenize_en(summ_text, lowercase=True)
            summ_elmo = self._embed_doc(summ_tokens, prepend_START=True)
            summ_len = len(summ_elmo)
        else:
            # -> Inference Loop
            summ_len = kwargs.get('summ_len', None)
            generated_summ_tokens = [['<START>']]

        # To calculate loss
        vocab_prjtns = []
        _init_probe = encoder_states[-1].reshape(1, 1, -1)
        curr_deco_state = (_init_probe, torch.randn_like(_init_probe))
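        # Seed the decoder's hidden state with the final encoder state (the cell state is random),
        # matching the decoder's hidden_size of 2 * ELMO_EMBED_DIM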
        curr_pred_token = None
        for token_ix in range(summ_len):
            if summ_text is not None:
                curr_i = summ_elmo[token_ix].reshape(1, 1, -1)
            elif curr_pred_token is not None:
                # Append currently predicted token
                generated_summ_tokens[-1].append(curr_pred_token)
                # Just get the Elmo Embedding of the Last Word of the Last Sentence
                curr_i = self._embed_doc([generated_summ_tokens[-1]])[-1].reshape(1, 1, -1)
                # Start a New Line if necessary
                if curr_pred_token == '.':
                    generated_summ_tokens.append([])
            else:
                # Init input for prediction
                curr_i = self._embed_doc([generated_summ_tokens[-1]])[-1].reshape(1, 1, -1)

            p_vocab = torch.zeros(size=(1, len(ex_vocab)))
            p_attn = torch.zeros(size=(1, len(ex_vocab)))

            # Run through the decoder
            curr_embed_output, curr_deco_state = self.decoder(curr_i, curr_deco_state)

            # Extract the hidden state vector
            curr_deco_hidd, _ = curr_deco_state

            # Calculate Context Vector
            curr_enco_attn = self._align(s=curr_deco_hidd.squeeze(dim=1),
                                         h=encoder_states,
                                         alignment_model=self.alignment_model)
            curr_enco_attn = self.sm_dim0(curr_enco_attn)
            curr_enco_ctxt = torch.matmul(curr_enco_attn, encoder_states)

            # Concatenate Context & Decoder Hidden State
            state_ctxt_concat = torch.cat((curr_deco_hidd.squeeze(), curr_enco_ctxt))

            # Project to Vocabulary
            vocab_prjtn = self.Vocab_Project_2(self.Vocab_Project_1(state_ctxt_concat))
            p_vocab[:, :self.VOCAB_SIZE] = vocab_prjtn
            for src_word, src_attn in zip(orig_tokens_flat, curr_enco_attn):
                p_attn[:, ex_vocab_2_ix[src_word]] += src_attn

            p_gen = self.sigmoid(
                self.Wh_pgen(curr_enco_ctxt) + self.Ws_pgen(curr_deco_hidd.squeeze()) + self.Wx_pgen(
                    curr_i.squeeze()))
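            # Pointer-generator mixture (See et al., 2017): combine the generation distribution
            # (p_vocab) and the copy/attention distribution (p_attn) with the soft switch p_gen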

            p_W = p_gen * p_vocab + (1 - p_gen) * p_attn
            curr_pred_token = ex_ix_2_vocab[p_W.argmax(dim=1).item()]

            vocab_prjtns.append(p_W)

        vocab_prjtns = torch.cat(vocab_prjtns, dim=0)

        return (vocab_prjtns, ex_vocab_2_ix, ex_ix_2_vocab)
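
Finally, a minimal inference sketch for the pointer-generator above; the toy vocabulary, input sentence and summ_len are hypothetical, helper/loader/constant are assumed to come from the same repository, and the ELMo files reuse the public URLs shown in the earlier example.

if __name__ == "__main__":
    toy_vocab = ["<START>", ".", "the", "cat", "sat", "on", "mat"]  # hypothetical vocabulary
    model = PointerGenerator(
        vocab=toy_vocab,
        elmo_weights_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5",
        elmo_options_file="https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json",
        elmo_embed_dim=1024)
    model.eval()

    # Without summ_text, the forward pass runs the inference loop for summ_len decoding steps
    with torch.no_grad():
        projections, vocab_2_ix, ix_2_vocab = model(orig_text="The cat sat on the mat.", summ_len=5)
    print(projections.shape)  # [summ_len, extended_vocab_size]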