Example #1
class ElmoEmbedding(torch.nn.Module):
    """
    Thin wrapper around the AllenNLP Elmo module that returns the first output representation.
    """
    def __init__(self,
                 name_or_path: str,
                 num_output_representations: int = 1,
                 requires_grad: bool = False,
                 do_layer_norm: bool = False,
                 dropout: float = 0.5,
                 keep_sentence_boundaries: bool = False,
                 scalar_mix_parameters: List[float] = None,
                 module: torch.nn.Module = None,
                 **kwargs):
        super().__init__()
        self.elmo = Elmo(name_or_path + 'options.json',
                         name_or_path + 'weights.hdf5',
                         num_output_representations,
                         requires_grad,
                         do_layer_norm,
                         dropout,
                         keep_sentence_boundaries=keep_sentence_boundaries,
                         scalar_mix_parameters=scalar_mix_parameters,
                         module=module)
        self.get_output = lambda x: x['elmo_representations'][0]
        self.batch_to_ids = batch_to_ids
        self.output_dim = self.elmo.get_output_dim()

    def forward(self, input_ids: torch.Tensor, sentences: List[List[str]], **kwargs):
        character_ids = self.batch_to_ids(sentences).to(input_ids.device)
        elmo_output = self.elmo(character_ids)  # the returned mask is identical to the one generated earlier, so it is unused
        return self.get_output(elmo_output)
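# A minimal usage sketch for the wrapper above (not from the original snippet). It
# assumes a local directory "elmo/" containing "options.json" and "weights.hdf5",
# since the constructor simply concatenates name_or_path with those file names.
def _elmo_embedding_demo():
    embedder = ElmoEmbedding("elmo/", dropout=0.0)
    sentences = [["The", "cat", "sat", "."], ["Hello", "world"]]
    # input_ids is only used inside forward() to pick the target device
    input_ids = torch.zeros(len(sentences), 4, dtype=torch.long)
    reps = embedder(input_ids, sentences)
    return reps.shape  # (batch, max_tokens, embedder.output_dim)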
Example #2
    def __init__(
        self,
        options_files: Dict[str, str],
        weight_files: Dict[str, str],
        do_layer_norm: bool = False,
        dropout: float = 0.5,
        requires_grad: bool = False,
        projection_dim: int = None,
        vocab_to_cache: List[str] = None,
        scalar_mix_parameters: List[float] = None,
        aligning_files: Dict[str, str] = None,
    ) -> None:
        super().__init__()

        if options_files.keys() != weight_files.keys():
            raise ConfigurationError("Keys for Elmo's options files and weights files don't match")

        aligning_files = aligning_files or {}
        output_dim = None
        for lang in weight_files.keys():
            name = "elmo_%s" % lang
            elmo = Elmo(
                options_files[lang],
                weight_files[lang],
                num_output_representations=1,
                do_layer_norm=do_layer_norm,
                dropout=dropout,
                requires_grad=requires_grad,
                vocab_to_cache=vocab_to_cache,
                scalar_mix_parameters=scalar_mix_parameters,
            )
            self.add_module(name, elmo)

            output_dim_tmp = elmo.get_output_dim()
            if output_dim is not None:
                # Verify that all ELMo embedders have the same output dimension.
                check_dimensions_match(
                    output_dim_tmp, output_dim, "%s output dim" % name, "elmo output dim"
                )

            output_dim = output_dim_tmp

        self.output_dim = output_dim

        if projection_dim:
            self._projection = torch.nn.Linear(output_dim, projection_dim)
            self.output_dim = projection_dim
        else:
            self._projection = None

        for lang in weight_files.keys():
            name = "aligning_%s" % lang
            aligning_matrix = torch.eye(output_dim)
            if lang in aligning_files and aligning_files[lang] != "":
                aligning_path = cached_path(aligning_files[lang])
                aligning_matrix = torch.FloatTensor(torch.load(aligning_path))

            aligning = torch.nn.Linear(output_dim, output_dim, bias=False)
            aligning.weight = torch.nn.Parameter(aligning_matrix, requires_grad=False)
            self.add_module(name, aligning)
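# Construction sketch for the multilingual __init__ above (illustrative only; the
# class name is not shown in this excerpt and the file paths are placeholders).
# The per-language keys of options_files and weight_files must match, and
# aligning_files may supply an optional saved mapping matrix per language:
#
#     embedder = <MultiLangElmoEmbedder>(
#         options_files={"en": "en_options.json", "fr": "fr_options.json"},
#         weight_files={"en": "en_weights.hdf5", "fr": "fr_weights.hdf5"},
#         aligning_files={"fr": "fr_aligning.pth"},
#     )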
Example #3
class ELMo(nn.Module):
    """
    Proxy for ELMo embeddings.
    https://github.com/allenai/allennlp/blob/master/tutorials/how_to/elmo.md
    Args:
        embedding_dropout float: dropout probability applied to the ELMo representations
        options_file str: path to the local ELMo options file
        weights_file str: path to the local ELMo weights file
    """
    def __init__(self,
                 embedding_dropout=0.,
                 options_file=default_options_file,
                 weights_file=default_weights_file,
                 **kwargs):
        super(ELMo, self).__init__()

        self.dropout = nn.Dropout(p=embedding_dropout)
        self.elmo = Elmo(options_file,
                         weights_file,
                         num_output_representations=1).to(device)
        self.embedding_dim = self.elmo.get_output_dim()

    def forward(self, sentences):
        char_ids = batch_to_ids(sentences).to(device)

        embedded = self.elmo(char_ids)

        embeddings = self.dropout(embedded['elmo_representations'][0])
        mask = embedded['mask']

        return embeddings, mask
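# Usage sketch for the proxy above (illustrative; it relies on the module-level
# default_options_file, default_weights_file and device referenced in __init__):
def _elmo_proxy_demo():
    elmo = ELMo(embedding_dropout=0.1)
    embeddings, mask = elmo([["A", "short", "sentence"], ["Another", "one"]])
    # embeddings: (batch, max_len, elmo.embedding_dim); mask: (batch, max_len)
    return embeddings.shape, mask.shape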
class EmbeddingsElmo(Module):

    # elmo_path = {small, medium, original}
    def __init__(self, elmo_path: str, input_vocabulary: List[str],
                 clear_text: bool):
        super().__init__()
        from allennlp.modules.elmo import Elmo
        if elmo_path in _elmo_models_map:
            options_file_path, weights_file_path = _elmo_models_map[elmo_path]
        else:
            options_file_path, weights_file_path = elmo_path + "_options.json", elmo_path + "_weights.hdf5"
        self.elmo_embeddings = Elmo(options_file=options_file_path,
                                    weight_file=weights_file_path,
                                    num_output_representations=1,
                                    vocab_to_cache=input_vocabulary)
        self.clear_text = clear_text

    # input:
    #   - sample_x:  Union[List[str], LongTensor]  - seq_in
    # output:
    #   - sample_x:  Union[List[str], LongTensor]  - seq_out
    #   - new_size:  int                           - seq_out
    #   - indices:   List[int]                     - seq_in
    @staticmethod
    def preprocess_sample_first(sample_x):
        return sample_x, None, None

    # input:
    #   - sample_x:  Union[List[str], LongTensor]  - seq_in
    #   - new_size:  int                           - seq_out
    #   - indices:   List[int]                     - seq_in
    # output:
    #   - sample_x:  Union[List[str], LongTensor]  - seq_out
    @staticmethod
    def preprocess_sample_next(sample_x, new_size, indices):
        return sample_x

    # inputs:
    #   - inputs:        Union[List[List[str]], LongTensor]  (batch x seq_in)
    # output:
    #   - output:        FloatTensor                         (batch x seq_out x hidden)
    #   - pad_mask:      LongTensor                          (batch x seq_out)
    #   - token_indices: List[List[int]]                     (batch x seq_in)
    def forward(self, inputs):
        if self.clear_text:
            from allennlp.modules.elmo import batch_to_ids
            inputs = batch_to_ids(inputs)
            inputs = inputs.to(default_device)
            return self.elmo_embeddings(
                inputs)["elmo_representations"][0], None, None
        else:
            return self.elmo_embeddings(
                inputs, inputs)["elmo_representations"][0], inputs, None

    def get_output_dim(self):
        return self.elmo_embeddings.get_output_dim()

    @staticmethod
    def is_fixed():
        return True
Example #5
class SentenceElmo(nn.Module):
    def __init__(self,
                 options_file,
                 weight_file,
                 tokenizer,
                 average_mod='mean',
                 max_seq_length=128):
        super().__init__()
        assert average_mod in {'mean', 'max', 'last'}

        self.elmo = Elmo(options_file=options_file,
                         weight_file=weight_file,
                         num_output_representations=1,
                         requires_grad=True)

        self.tokenizer = tokenizer
        self.average_mod = average_mod
        self.max_seq_length = max_seq_length

    def get_word_embedding_dimension(self) -> int:
        return self.elmo.get_output_dim()

    def forward(self, features):
        output = self.elmo(features['input_ids'])
        token_embeddings = output['elmo_representations'][0]

        features = {}
        if self.average_mod == 'mean':
            features['sentence_embedding'] = token_embeddings.mean(axis=1)
        elif self.average_mod == 'max':
            features['sentence_embedding'] = token_embeddings.max(
                axis=1).values
        else:
            last_token_indices = output['mask'].sum(axis=1) - 1
            features['sentence_embedding'] = token_embeddings[
                torch.arange(token_embeddings.shape[0]), last_token_indices, :]

        return features

    def tokenize(self, texts: List[str]):
        tokenized_texts = [
            self.tokenizer.tokenize(text)[:self.max_seq_length]
            for text in texts
        ]
        input_ids = batch_to_ids(tokenized_texts)

        output = {'input_ids': input_ids}
        return output

    def save(self, output_path: str):
        torch.save(self.elmo.state_dict(),
                   os.path.join(output_path, 'model.pth'))
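# Usage sketch for SentenceElmo (illustrative; the option/weight paths and the
# whitespace tokenizer below are placeholders, not part of the original code):
def _sentence_elmo_demo():
    class _WhitespaceTokenizer:
        def tokenize(self, text):
            return text.split()

    model = SentenceElmo("options.json", "weights.hdf5",
                         tokenizer=_WhitespaceTokenizer(), average_mod='mean')
    features = model.tokenize(["ELMo gives contextual embeddings", "A short sentence"])
    out = model(features)
    return out['sentence_embedding'].shape  # (batch, model.get_word_embedding_dimension())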
Example #6
class Embedding(nn.Module):
    def __init__(self,
                 char_vocab_size,
                 glove_vocab_size,
                 word_vocab_size,
                 embed_dim,
                 dropout,
                 elmo=False,
                 elmo_options_file=None,
                 elmo_weights_file=None,
                 glove_cpu=False):
        super(Embedding, self).__init__()
        self.word_embedding = WordEmbedding(word_vocab_size, embed_dim)
        self.char_embedding = CharEmbedding(char_vocab_size, embed_dim)
        self.glove_embedding = WordEmbedding(glove_vocab_size,
                                             embed_dim,
                                             requires_grad=False,
                                             cpu=glove_cpu)
        self.output_size = 2 * embed_dim
        self.highway1 = Highway(self.output_size, dropout)
        self.highway2 = Highway(self.output_size, dropout)
        if elmo:
            assert elmo_options_file is not None and elmo_weights_file is not None
            from allennlp.modules.elmo import Elmo
            self.elmo = Elmo(elmo_options_file,
                             elmo_weights_file,
                             1,
                             dropout=0)
            self.output_size += self.elmo.get_output_dim()
        else:
            self.elmo = None

    def load_glove(self, glove_emb_mat):
        device = self.glove_embedding.embedding.weight.device
        glove_emb_mat = glove_emb_mat.to(device)
        glove_emb_mat = torch.cat([
            torch.zeros(2,
                        glove_emb_mat.size()[-1]).to(device), glove_emb_mat
        ],
                                  dim=0)
        self.glove_embedding.embedding.weight = torch.nn.Parameter(
            glove_emb_mat, requires_grad=False)

    def forward(self, cx, gx, x, ex=None):
        cx = self.char_embedding(cx)
        gx = self.glove_embedding(gx)
        output = torch.cat([cx, gx], -1)
        output = self.highway2(self.highway1(output))
        if self.elmo is not None:
            elmo, = self.elmo(ex)['elmo_representations']
            output = torch.cat([output, elmo], 2)
        return output
Example #7
def train_on(dataset, params):
    print("Using hyperparameter configuration:", params)

    losses = []
    state_dicts = []
    kfold = StratifiedKFold(dataset, k=10, grouping=origin_of)

    for train, val in kfold:
        # TODO: Figure how much of the following code we can put outside the loop

        vocab = Vocabulary.from_instances(dataset)
        # TODO: Figure out the best parameters here
        elmo = Elmo(cached_path(OPTIONS_FILE),
                    cached_path(WEIGHTS_FILE),
                    num_output_representations=2,
                    dropout=params["dropout"]
                    )  # TODO: Does dropout refer to the LSTM or ELMo?
        word_embeddings = ELMoTextFieldEmbedder({"tokens": elmo})
        # TODO: Figure out the best parameters here
        lstm = PytorchSeq2VecWrapper(
            torch.nn.LSTM(input_size=elmo.get_output_dim(),
                          hidden_size=64,
                          num_layers=params["num_layers"],
                          batch_first=True))

        model = RuseModel(word_embeddings, lstm, vocab)
        optimizer = optim.Adam(model.parameters())
        # TODO: What kind of iterator should be used?
        iterator = BucketIterator(batch_size=params["batch_size"],
                                  sorting_keys=[("mt_sent", "num_tokens"),
                                                ("ref_sent", "num_tokens")])
        iterator.index_with(vocab)

        # TODO: Figure out best hyperparameters
        trainer = Trainer(model=model,
                          optimizer=optimizer,
                          iterator=iterator,
                          cuda_device=0,
                          train_dataset=train,
                          validation_dataset=val,
                          patience=5,
                          num_epochs=100)
        trainer.train()
        # TODO: Better way to access the validation loss?
        loss, _ = trainer._validation_loss()
        losses.append(loss)
        state_dicts.append(model.state_dict())

    mean_loss = np.mean(losses)
    print("Mean validation loss was:", mean_loss)

    return TrainResults(cv_loss=mean_loss, state_dicts=state_dicts)
Example #8
    class ELMoWordEncoder(torch.nn.Module):
        def __init__(self, options_file: str, weight_file: str):
            super().__init__()
            self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
            self.out_dim = self.elmo.get_output_dim()

        def forward(
                self,
                character_ids: ty.Sequence[torch.Tensor]) -> WordEncoderOutput:
            # FIXME: this should be dealt with in digitize/collate (or should it ?)
            padded_characters = pad_sequence(
                [c.squeeze(0) for c in character_ids], batch_first=True)
            embeddings = self.elmo(padded_characters)
            seq_lens = embeddings["mask"].sum(dim=-1)
            return WordEncoderOutput(embeddings["elmo_representations"][0],
                                     seq_lens)
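        # Feeding sketch (illustrative, not from the original): each element of
        # character_ids is the (1, n_tokens, 50) tensor that
        # allennlp.modules.elmo.batch_to_ids produces for a single sentence, e.g.
        #     char_ids = [batch_to_ids([sent]) for sent in sentences]
        #     out = encoder(char_ids)  # sentences are padded together inside forward()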
Example #9
class ELMo(nn.Module):
    def __init__(self):

        super(ELMo, self).__init__()
        options_path = settings.PATH_TO_ELMO_OPTIONS
        weights_path = settings.PATH_TO_ELMO_WEIGHTS

        self.embedding = Elmo(
            options_path, 
            weights_path,
            num_output_representations=1,
            dropout=0.5,
            scalar_mix_parameters=[-9e10, 1, -9e10]
        )
        # scalar_mix_parameters=[-9e10, -9e10, 1]
        self.out_ftrs = self.embedding.get_output_dim()

    def forward(self, x):
        x = self.embedding(x)
        masks = x["mask"].float()
        x = x["elmo_representations"][0]
        return {"representation": x, "masks": masks}
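# Why scalar_mix_parameters=[-9e10, 1, -9e10] isolates a single layer: AllenNLP's
# ScalarMix softmaxes these values, so the mixture weights collapse onto the middle
# of the three mixed ELMo layers (the first biLSTM layer of the standard 2-layer
# model). A small sketch of that computation:
def _show_mix_weights():
    import torch
    params = torch.tensor([-9e10, 1.0, -9e10])
    return torch.softmax(params, dim=0)  # ~tensor([0., 1., 0.])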
Example #10
class CompressDecoder(torch.nn.Module):
    def __init__(self,
                 context_dim,
                 dec_state_dim,
                 enc_hid_dim,
                 text_field_embedder,
                 aggressive_compression: int = -1,
                 keep_threshold: float = 0.5,
                 abs_board_file="/home/cc/exComp/board.txt",
                 gather='mean',
                 dropout=0.5,
                 dropout_emb=0.2,
                 valid_tmp_path='/scratch/cluster/jcxu/exComp',
                 serilization_name: str = "",
                 vocab=None,
                 elmo: bool = False,
                 elmo_weight: str = "elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"):
        super().__init__()
        self.use_elmo = elmo
        self.serilization_name = serilization_name
        if elmo:
            from allennlp.modules.elmo import Elmo, batch_to_ids
            from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
            self.vocab = vocab

            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
            weight_file = elmo_weight
            self.elmo = Elmo(options_file, weight_file, 1, dropout=dropout_emb)
            # print(self.elmo.get_output_dim())
            # self.word_emb_dim = text_field_embedder.get_output_dim()
            # self._context_layer = PytorchSeq2SeqWrapper(
            #     torch.nn.LSTM(self.word_emb_dim + self.elmo.get_output_dim(), self.word_emb_dim,
            #                   batch_first=True, bidirectional=True))
            self.word_emb_dim = self.elmo.get_output_dim()
        else:
            self._text_field_embedder = text_field_embedder
            self.word_emb_dim = text_field_embedder.get_output_dim()

        self.XEloss = torch.nn.CrossEntropyLoss(reduction='none')
        self.device = get_device()

        # self.rouge_metrics_compression = RougeStrEvaluation(name='cp', path_to_valid=valid_tmp_path,
        #                                                     writting_address=valid_tmp_path,
        #                                                     serilization_name=serilization_name)
        # self.rouge_metrics_compression_best_possible = RougeStrEvaluation(name='cp_ub', path_to_valid=valid_tmp_path,
        #                                                                   writting_address=valid_tmp_path,
        #                                                                   serilization_name=serilization_name)
        self.enc = EncCompression(inp_dim=self.word_emb_dim, hid_dim=enc_hid_dim, gather=gather)  # TODO dropout

        self.aggressive_compression = aggressive_compression
        self.relu = torch.nn.ReLU()

        self.attn = NewAttention(enc_dim=self.enc.get_output_dim(),
                                 dec_dim=self.enc.get_output_dim_unit() * 2 + dec_state_dim)

        self.concat_size = self.enc.get_output_dim() + self.enc.get_output_dim_unit() * 2 + dec_state_dim
        self.valid_tmp_path = valid_tmp_path
        if self.aggressive_compression < 0:
            self.XELoss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-1)
            # self.nn_lin = torch.nn.Linear(self.concat_size, self.concat_size)
            # self.nn_lin2 = torch.nn.Linear(self.concat_size, 2)

            self.ff = FeedForward(input_dim=self.concat_size, num_layers=3,
                                  hidden_dims=[self.concat_size, self.concat_size, 2],
                                  activations=[torch.nn.Tanh(), torch.nn.Tanh(), lambda x: x],
                                  dropout=dropout
                                  )
            # Keep threshold

            # self.keep_thres = list(np.arange(start=0.2, stop=0.6, step=0.075))
            self.keep_thres = [0.0, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 1.0]
            self.rouge_metrics_compression_dict = OrderedDict()
            for thres in self.keep_thres:
                self.rouge_metrics_compression_dict["{}".format(thres)] = RougeStrEvaluation(name='cp_{}'.format(thres),
                                                                                             path_to_valid=valid_tmp_path,
                                                                                             writting_address=valid_tmp_path,
                                                                                             serilization_name=serilization_name)

    def encode_sent_and_span_paral(self, text,  # batch, max_sent, max_word
                                   text_msk,  # batch, max_sent, max_word
                                   span,  # batch, max_sent_num, max_span_num, max_word
                                   sent_idx  # batch size
                                   ):
        this_text = two_dim_index_select(text['tokens'], sent_idx)  # batch, max_word
        from allennlp.modules.elmo import batch_to_ids
        if self.use_elmo:
            this_text_list: List = this_text.tolist()
            text_str_list = []
            for sample in this_text_list:
                s = [self.vocab.get_token_from_index(x) for x in sample]
                text_str_list.append(s)
            character_ids = batch_to_ids(text_str_list).to(self.device)
            this_context = self.elmo(character_ids)
            # print(this_context['elmo_representations'][0].size())
            this_context = this_context['elmo_representations'][0]
        else:
            this_text = {'tokens': this_text}
            this_context = self._text_field_embedder(this_text)

        num_doc, max_word, inp_dim = this_context.size()
        batch_size = sent_idx.size()[0]
        assert batch_size == num_doc

        # text is the original text of the selected sentence.
        # this_context = two_dim_index_select(context, sent_idx)  # batch, max_word, hdim
        this_context_mask = two_dim_index_select(text_msk, sent_idx)  # batch, max_word
        this_span = two_dim_index_select(span, sent_idx)  # batch , nspan, max_word

        concat_rep_of_compression, \
        span_msk, original_sent_rep = self.enc.forward(word_emb=this_context,
                                                       word_emb_msk=this_context_mask,
                                                       span=this_span)
        return concat_rep_of_compression, span_msk, original_sent_rep

    def encode_sent_and_span(self, text, text_msk, span, batch_idx, sent_idx):
        context = self._text_field_embedder(text)
        num_doc, max_sent, max_word, inp_dim = context.size()
        num_doc_, max_sent_, nspan = span.size()[0:-1]
        assert num_doc == num_doc_
        assert max_sent == max_sent_
        this_context = context[batch_idx, sent_idx, :, :].unsqueeze(0)
        this_span = span[batch_idx, sent_idx, :, :].unsqueeze(0)
        this_context_mask = text_msk[batch_idx, sent_idx, :].unsqueeze(0)
        flattened_enc, attn_dist, \
        spans_rep, span_msk, score \
            = self.enc.forward(word_emb=this_context,
                               word_emb_msk=this_context_mask,
                               span=this_span)
        return flattened_enc, spans_rep, span_msk
        # 1, hid*2      1, span num, hid        1, span num

    def indep_compression_judger(self, reps):
        # t, batch_size_, max_span_num,self.concat_size
        timestep, batch_size, max_span_num, dim = reps.size()
        score = self.ff.forward(reps)
        # lin_out = self.nn_lin(reps)
        # activated = torch.sigmoid(lin_out)
        # score = self.nn_lin2(activated)
        if random.random() < 0.005:
            print("score: {}".format(score[0]))
        return score

    def get_out_dim(self):
        return self.concat_size

    def forward_parallel(self, sent_decoder_states,  # t, batch, hdim
                         sent_decoder_outputs_logit,  # t, batch
                         document_rep,  # batch, hdim
                         text,  # batch, max_sent, max_word
                         text_msk,  # batch, max_sent, max_word
                         span):  # batch, max_sent_num, max_span_num, max_word
        # Encode compression options given sent emission.
        # output scores, attn dist, ...
        t, batch_size, hdim = sent_decoder_states.size()
        t_, batch_size_ = sent_decoder_outputs_logit.size()  # invalid bits are -1
        batch, max_sent, max_span_num, max_word = span.size()
        # assert t == t_
        t = min(t, t_)
        assert batch_size == batch == batch_size_
        if self.aggressive_compression > 0:
            all_attn_dist = torch.zeros((t, batch_size, max_span_num)).to(self.device)
            all_scores = torch.ones((t, batch_size, max_span_num)).to(self.device) * -100
        else:
            all_attn_dist = None
            all_scores = None
        all_reps = torch.zeros((t, batch_size_, max_span_num, self.concat_size), device=self.device)
        for timestep in range(t):
            dec_state = sent_decoder_states[timestep]  # batch, dim
            logit = sent_decoder_outputs_logit[timestep]  # batch

            # valid_mask = (logit > 0)
            positive_logit = self.relu(logit.float()).long()  # turn -1 to 0

            span_t, span_msk_t, sent_t = self.encode_sent_and_span_paral(text=text,
                                                                         text_msk=text_msk,
                                                                         span=span,
                                                                         sent_idx=positive_logit)
            # sent_t : batch, sent_dim
            # span_t: batch, span_num, span_dim
            # span_msk_t: batch, span_num [[1,1,1,0,0,0],

            concated_rep_high_level = torch.cat([dec_state, document_rep, sent_t], dim=1)
            # batch, DIM
            if self.aggressive_compression > 0:
                attn_dist, score = self.attn.forward_one_step(enc_state=span_t,
                                                              dec_state=concated_rep_high_level,
                                                              enc_mask=span_msk_t.float())
            # attn_dist: batch, span num
            # score:    batch, span num

            # concated_rep: batch, dim ==> batch, 1, dim ==> batch, max_span_num, dim
            expanded_concated_rep = concated_rep_high_level.unsqueeze(1).expand((batch, max_span_num, -1))
            all_reps[timestep, :, :, :] = torch.cat([expanded_concated_rep, span_t], dim=2)
            if self.aggressive_compression > 0:
                all_attn_dist[timestep, :, :] = attn_dist
                all_scores[timestep, :, :] = score

        return all_attn_dist, all_scores, all_reps

    def comp_loss_inf_deletion(self,
                               decoder_outputs_logit,  # gold label!!!!
                               # span_seq_label,  # batch, max sent num
                               span_rouge,  # batch, max sent num, max compression num
                               scores,
                               comp_rouge_ratio,
                               loss_thres=1
                               ):
        """

        :param decoder_outputs_logit: gold sentence-emission labels, [timestep, batch]
        :param span_rouge: [batch, max_sent, max_compression]
        :param scores: [timestep, batch, max_compression, 2]
        :param comp_rouge_ratio: [batch_size, max_sent, max_compression]
        :return:
        """
        tim, bat = decoder_outputs_logit.size()
        time, batch, max_span, _ = scores.size()
        batch_, sent_len, max_sp = span_rouge.size()
        assert batch_ == batch == bat
        assert time == tim
        assert max_sp == max_span
        goal_rouge_label = torch.ones((tim, batch, max_span), device=self.device, dtype=torch.long,
                                      ) * (-1)
        weights = torch.ones((tim, batch, max_span), device=self.device, dtype=torch.float)
        decoder_outputs_logit_mask = (decoder_outputs_logit >= 0).unsqueeze(2).expand(
            (time, batch, max_span)).float().view(-1)
        decoder_outputs_logit = torch.nn.functional.relu(decoder_outputs_logit).long()
        z = torch.zeros((1), device=self.device)
        for tt in range(tim):
            decoder_outputs_logit_t = decoder_outputs_logit[tt]
            out = two_dim_index_select(inp=comp_rouge_ratio, index=decoder_outputs_logit_t)
            label = torch.gt(out, loss_thres).long()

            mini_mask = torch.gt(out, 0.01).float()

            # baseline_mask = 1 - torch.lt(torch.abs(out - 0.99), 0.01).float()  # baseline will be 0

            # weight = torch.max(input=-out + 0.5, other=z) + 1
            # weights[tt] = mini_mask * baseline_mask
            weights[tt] = mini_mask
            goal_rouge_label[tt] = label
        probs = scores.view(-1, 2)
        goal_rouge_label = goal_rouge_label.view(-1)
        weights = weights.view(-1)
        loss = self.XELoss(input=probs, target=goal_rouge_label)
        loss = loss * decoder_outputs_logit_mask * weights
        return torch.mean(loss)

    def comp_loss(self, decoder_outputs_logit,  # gold label!!!!
                  scores,
                  span_seq_label,  # batch, max sent num
                  span_rouge,  # batch, max sent num, max compression num
                  comp_rouge_ratio
                  ):
        t, batch = decoder_outputs_logit.size()
        t_, batch_, comp_num = scores.size()
        b, max_sent = span_seq_label.size()
        # b_, max_sen, max_comp_, _ = span.size()
        _b, max_sent_, max_comp = span_rouge.size()
        assert batch == batch_ == b == _b
        assert max_sent_ == max_sent
        assert comp_num == max_comp
        span_seq_label = span_seq_label.long()
        total_loss = torch.zeros((t, b)).to(self.device)
        # print(decoder_outputs_logit)
        # print(span_seq_label)
        for timestep in range(t):

            # this is the sent idx
            for batch_idx in range(b):
                logit = decoder_outputs_logit[timestep][batch_idx]
                # print(logit)
                # decoder_outputs_logit should be the gold label for sentence emission.
                # if it's 0 or -1, then we skip supervision.
                if logit < 0:
                    continue
                ref_rouge_score = comp_rouge_ratio[batch_idx][logit]
                num_of_compression = ref_rouge_score.size()[0]

                _supervision_label_msk = (ref_rouge_score > 0.98).float()
                label = torch.from_numpy(np.arange(num_of_compression)).to(self.device).long()
                score_t = scores[timestep][batch_idx].unsqueeze(0)  # comp num
                score_t = score_t.expand(num_of_compression, -1)
                # label = span_seq_label[batch_idx][logit].unsqueeze(0)

                loss = self.XEloss(score_t, label)
                # print(loss)
                loss = _supervision_label_msk * loss
                total_loss[timestep][batch_idx] = torch.sum(loss)
                # sent_msk_t = two_dim_index_select(sent_mask, logit)

        return torch.mean(total_loss)

    def _dec_compression_one_step(self, predict_compression,
                                  sp_meta,
                                  word_sent: List[str], keep_threshold: List[float],
                                  context: List[List[str]] = None):

        full_set_len = set(range(len(word_sent)))
        # max_comp, _ = predict_compression.size

        preds = [full_set_len.copy() for _ in range(len(keep_threshold))]

        # Show all of the compression spans
        stat_compression = {}
        for comp_idx, comp_meta in enumerate(sp_meta):
            p = predict_compression[comp_idx][1]
            node_type, sel_idx, rouge, ratio = comp_meta
            if node_type != "BASELINE":
                selected_words = [x for idx, x in enumerate(word_sent) if idx in sel_idx]
                selected_words_str = "_".join(selected_words)
                stat_compression["{}".format(selected_words_str)] = {
                    "prob": float("{0:.2f}".format(p)),  # float("{0:.2f}".format())
                    "type": node_type,
                    "rouge": float("{0:.2f}".format(rouge)),
                    "ratio": float("{0:.2f}".format(ratio)),
                    "sel_idx": sel_idx,
                    "len": len(sel_idx)
                }
        stat_compression_order = OrderedDict(
            sorted(stat_compression.items(), key=lambda item: item[1]["prob"], reverse=True))  # Python 3
        for idx, _keep_thres in enumerate(keep_threshold):
            history: List[str] = context[idx]
            his_set = set((" ".join(history)).split(" "))
            for key, value in stat_compression_order.items():
                p = value['prob']
                sel_idx = value['sel_idx']
                sel_txt = set([word_sent[x] for x in sel_idx])
                if sel_txt - his_set == set():
                    # print("Save big!")
                    # print("Context: {}\tCandidate: {}".format(his_set, sel_txt))
                    preds[idx] = preds[idx] - set(value['sel_idx'])
                    continue
                if p > _keep_thres:
                    preds[idx] = preds[idx] - set(value['sel_idx'])

        preds = [list(x) for x in preds]
        for pred in preds:
            pred.sort()
        # Visual output
        visual_outputs: List[str] = []
        words_for_evaluation: List[str] = []
        meta_keep_ratio_word = []

        for idx, compression in enumerate(preds):
            output = [word_sent[jdx] if (jdx in compression) else '_' + word_sent[jdx] + '_' for jdx in
                      range(len(word_sent))]
            visual_outputs.append(" ".join(output))

            words = [word_sent[x] for x in compression]
            meta_keep_ratio_word.append(float(len(words) / len(word_sent)))
            # meta_kepp_ratio_span.append(1 - float(len(survery['type'][idx]) / len(sp_meta)))
            words = " ".join(words)
            words = easy_post_processing(words)
            # print(words)
            words_for_evaluation.append(words)
        d: List[List] = []
        for kep_th, vis, words_eva, keep_word_ratio in zip(keep_threshold, visual_outputs, words_for_evaluation,
                                                           meta_keep_ratio_word):
            d.append([kep_th, vis, words_eva, keep_word_ratio])
        return stat_compression_order, d

    def decode_inf_deletion(self,
                            sent_decoder_outputs_logit,  # time, batch
                            span_prob,  # time, batch, max_comp, 2
                            metadata: List,
                            span_meta: List,
                            span_rouge,  # batch, sent, max_comp
                            keep_threshold: List[float]
                            ):
        batch_size, max_sent_num, max_comp_num = span_rouge.size()
        t, batsz, max_comp, _ = span_prob.size()
        span_score = torch.nn.functional.softmax(span_prob, dim=3).cpu().numpy()
        timestep, batch = sent_decoder_outputs_logit.size()
        sent_decoder_outputs_logit = sent_decoder_outputs_logit.cpu().data

        for idx, m in enumerate(metadata):
            abs_s = [" ".join(s) for s in m["abs_list"]]
            comp_exe = CompExecutor(span_meta=span_meta[idx],
                                    sent_idxs=sent_decoder_outputs_logit[:, idx],
                                    prediction_score=span_score[:, idx, :, :],
                                    abs_str=abs_s,
                                    name=m['name'],
                                    doc_list=m["doc_list"],
                                    keep_threshold=keep_threshold,
                                    part=m['name'], ser_dir=self.valid_tmp_path,
                                    ser_fname=self.serilization_name
                                    )
            # processed_words, del_record, \
            # compressions, full_sents, \
            bag_pred_eval = comp_exe.run()
            full_sents: List[List[str]] = comp_exe.full_sents
            # assemble full sents
            full_sents = [" ".join(x) for x in full_sents]

            # visual to console
            # use a distinct loop variable so the enumerate index above is not shadowed
            for k in range(len(keep_threshold)):
                self.rouge_metrics_compression_dict["{}".format(keep_threshold[k])](pred=bag_pred_eval[k],
                                                                                    ref=[abs_s], origin=full_sents
                                                                                    )
Example #11
class LstmCnnOutElmo(nn.Module):
    def __init__(self,
                 vocabs,
                 elmo_option,
                 elmo_weight,
                 lstm_hidden_size,
                 lstm_dropout=0.5,
                 feat_dropout=0.5,
                 elmo_dropout=0.5,
                 parameters=None,
                 output_bias=True,
                 elmo_finetune=False,
                 tag_scheme='bioes'):
        super(LstmCnnOutElmo, self).__init__()

        self.vocabs = vocabs
        self.label_size = len(self.vocabs['label'])
        # input features
        self.word_embed = nn.Embedding(parameters['word_embed_num'],
                                       parameters['word_embed_dim'],
                                       padding_idx=C.PAD_INDEX)

        self.elmo = Elmo(elmo_option,
                         elmo_weight,
                         num_output_representations=1,
                         requires_grad=elmo_finetune,
                         dropout=elmo_dropout)
        self.elmo_dim = self.elmo.get_output_dim()

        self.word_dim = self.word_embed.embedding_dim
        self.feat_dim = self.word_dim
        # layers
        self.lstm = LSTM(input_size=self.feat_dim,
                         hidden_size=lstm_hidden_size,
                         batch_first=True,
                         bidirectional=True)
        self.output_linear = Linear(self.lstm.output_size + self.elmo_dim,
                                    self.label_size,
                                    bias=output_bias)
        self.crf = CRF(vocabs['label'], tag_scheme=tag_scheme)
        self.feat_dropout = nn.Dropout(p=feat_dropout)
        self.lstm_dropout = nn.Dropout(p=lstm_dropout)

    def forward_nn(self, token_ids, elmo_ids, lens, return_hidden=False):
        # word representation
        word_in = self.word_embed(token_ids)
        feats = self.feat_dropout(word_in)

        # LSTM layer
        lstm_in = R.pack_padded_sequence(feats,
                                         lens.tolist(),
                                         batch_first=True)
        lstm_out, _ = self.lstm(lstm_in)
        lstm_out, _ = R.pad_packed_sequence(lstm_out, batch_first=True)
        lstm_out = self.lstm_dropout(lstm_out)

        # Elmo output
        elmo_out = self.elmo(elmo_ids)['elmo_representations'][0]
        combined_out = torch.cat([lstm_out, elmo_out], dim=2)

        # output linear layer
        linear_out = self.output_linear(combined_out)
        if return_hidden:
            return linear_out, combined_out.tolist()
        else:
            return linear_out, None

    def predict(self,
                token_ids,
                elmo_ids,
                lens,
                return_hidden=False,
                return_conf_score=False):
        self.eval()
        logits, lstm_out = self.forward_nn(token_ids,
                                           elmo_ids,
                                           lens,
                                           return_hidden=return_hidden)
        logits_padded = self.crf.pad_logits(logits)
        _scores, preds = self.crf.viterbi_decode(logits_padded, lens)
        if return_conf_score:
            conf_score = self.crf.calc_conf_score(logits, preds)
        else:
            conf_score = None
        preds = preds.data.tolist()
        self.train()
        return preds, lstm_out, conf_score
Example #12
class BiLSTM(nn.Module):
    def __init__(self, config: Dict):
        super(BiLSTM, self).__init__()
        self.bilstm = nn.LSTM(input_size=config['embed_dim'],
                              hidden_size=config['hidden_size'],
                              bidirectional=True,
                              batch_first=True,
                              num_layers=config['rnn_layers'],
                              bias=True)
        self.bn1 = nn.BatchNorm1d(2 * config['hidden_size'])
        self.a1 = nn.SELU()
        # concat hidden states of two directions

        self.linear = nn.Linear(2 * config['hidden_size'], config['num_cat'])
        print(f"tag vocabulary size: {config['num_cat']}")
        self.config = config
        if config['method'] == 'elmo':
            self.init_elmo()
        elif config['method'] == 'glove':
            self.init_glove()
        elif config['method'] == 'bert':
            self.init_bert()
        else:
            raise Exception(
                f'the method must be one of the following: {methods}')

    def forward(self, x):
        x = self.get_embedding(x)
        x = pack_sequence(x, enforce_sorted=False)
        packed_output, (h, c) = self.bilstm(x)
        x, each_len = pad_packed_sequence(packed_output)
        example_len = each_len[0]
        x = x.permute(1, 0, 2)
        # x of size (batch_size, max_len, 2 * hidden_size)
        x = torch.cat([x[i][:l] for i, l in enumerate(each_len)], dim=0)
        # reshape to (num_of_token, 2 * hidden_size)
        x = self.bn1(x)
        x = self.a1(x)
        x = self.linear(x)
        # x of size (num_of_token, num_cat)
        return x, each_len

    def init_bert(self):
        bert_shortcut = self.config['bert_shortcut']
        from pytorch_pretrained_bert import BertTokenizer, BertModel
        self.tokenizer = BertTokenizer.from_pretrained(bert_shortcut)
        self.bert_model = BertModel.from_pretrained(bert_shortcut).to(device)
        # self.bert_model.eval()
        self.bert_model.train()

    def get_bert(self, sentences: List[List[str]]):
        sentences = [' '.join(s) for s in sentences]
        tokenized_sentences = [self.tokenizer.tokenize(s) for s in sentences]
        sentences = [s.split(' ') for s in sentences]
        grouping_list = [
            group_subword(ts, s)
            for ts, s in zip(tokenized_sentences, sentences)
        ]

        # padding with the default [PAD]
        max_len = max([len(ts) for ts in tokenized_sentences])
        tokenized_sentences = [
            ts + ['[PAD]'] * (max_len - len(ts)) for ts in tokenized_sentences
        ]

        indexed_sentences = [
            self.tokenizer.convert_tokens_to_ids(s)
            for s in tokenized_sentences
        ]
        idx_tensors = torch.tensor(indexed_sentences).to(device)
        with torch.no_grad():
            try:
                encoded_layers, _ = self.bert_model(idx_tensors)
                # compute the sum of last 4 layers
                token_embeddings = torch.stack(encoded_layers[-4:], dim=0)
                token_embeddings = token_embeddings.permute(1, 2, 0, 3)
                # token_embeddings of size (num_sentences, max_len, 4, num_of_hidden_features)
                # sum the last 4 layers to get the token embeddings
                token_embeddings = torch.sum(token_embeddings, dim=2)
                # print(token_embeddings.size())
                # print(token_embeddings)
                return [
                    group_embeddings(g, te)
                    for g, te in zip(grouping_list, token_embeddings)
                ]
            except RuntimeError as e:
                print(e)
                # print('\n'.join([' '.join(s) for s in sentences]))
                print(idx_tensors)
                import sys
                sys.exit(0)

    def init_glove(self):
        self.word2id = np.load(self.config['word2id_path'],
                               allow_pickle=True).tolist()
        glove_embeddings = torch.from_numpy(
            np.load(self.config['glove_path'], allow_pickle=True))
        glove_embeddings = glove_embeddings.to(device)
        self.glove = nn.Embedding(self.config['vocab_size'],
                                  self.config['embed_dim'])
        self.glove.weight.data.copy_(glove_embeddings)
        self.embed_dim = self.config['embed_dim']

    def get_glove(self, sentences: List[List[str]]):
        max_len = max(map(lambda x: len(x), sentences))
        sentence_lists = list(
            map(lambda x: list(map(lambda w: self.word2id.get(w, 0), x)),
                sentences))
        sentence_lists = list(
            map(lambda x: x + [self.config['vocab_size'] - 1] * (max_len - len(x)),
                sentence_lists))
        sentence_lists = torch.LongTensor(sentence_lists).to(device)
        embeddings = self.glove(sentence_lists)
        # pack
        return embeddings

    def init_elmo(self):
        from allennlp.modules.elmo import Elmo, batch_to_ids
        self.elmo = Elmo(self.config['elmo_options_file'],
                         self.config['elmo_weights_file'], 1)
        self.embed_dim = self.elmo.get_output_dim()

    def get_elmo(self, sentences: List[List[str]]):
        # sentences = [s.split(' ') for s in sentences]
        character_ids = batch_to_ids(sentences).to(device)
        embeddings = self.elmo(character_ids)['elmo_representations'][0]
        # pack
        return embeddings

    def get_embedding(self, sentences: List[List[str]]):
        if self.config['method'] == 'elmo':
            return self.get_elmo(sentences)
        elif self.config['method'] == 'glove':
            return self.get_glove(sentences)
        elif self.config['method'] == 'bert':
            return self.get_bert(sentences)
        else:
            raise SimpleCCGException(
                f'the method must be one of the following {methods}')
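# Construction sketch for the 'elmo' path of BiLSTM above (illustrative; every value
# is a placeholder). Note that config['embed_dim'] sizes the LSTM input, so for the
# 'elmo' method it has to equal the ELMo output dim even though init_elmo() later
# overwrites self.embed_dim:
#
#     config = {
#         'embed_dim': 1024, 'hidden_size': 256, 'rnn_layers': 2, 'num_cat': 48,
#         'method': 'elmo',
#         'elmo_options_file': 'options.json', 'elmo_weights_file': 'weights.hdf5',
#     }
#     model = BiLSTM(config)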
class Encoder(nn.Module):
    def create_ner_embed(self, opt):
        ner_vocab_size = opt['ner_vocab_size']
        ner_embed_dim = opt['ner_dim']
        self.ner_embedding = nn.Embedding(ner_vocab_size,
                                          ner_embed_dim,
                                          padding_idx=0)
        return ner_embed_dim

    def create_pos_embed(self, opt):
        pos_vocab_size = opt['pos_vocab_size']
        pos_embed_dim = opt['pos_dim']
        self.pos_embedding = nn.Embedding(pos_vocab_size,
                                          pos_embed_dim,
                                          padding_idx=0)
        return pos_embed_dim

    def create_word_embed(self, opt, embedding):
        vocab_size = opt['vocab_size']
        embed_dim = 300
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.embedding.weight.data = embedding
        fixed_embedding = embedding[opt['embed_tune_partial']:]
        self.register_buffer('fixed_embedding', fixed_embedding)
        self.fixed_embedding = fixed_embedding
        return embed_dim

    def create_elmo(self, opt):
        self.elmo = Elmo(opt['elmo_config_path'],
                         opt['elmo_weight_path'],
                         num_output_representations=3)
        return self.elmo.get_output_dim()

    def __init__(self, opt, embedding):
        super(Encoder, self).__init__()
        self.dropout_p = opt['dropout_p']

        # self.eval_embed; eval_embed.weight.data (assigned after the model is created)
        # self.embedding; self.embedding.weight.data, self.fixed_embedding, self.embedding_dim=300
        self.embedding_dim = self.create_word_embed(opt, embedding)
        # self.elmo, self.elmo_size=1024
        self.elmo_size = self.create_elmo(opt)

        self.lstm = nn.LSTM(
            self.embedding_dim + self.elmo_size,  # 300+1024
            opt['encoder_lstm_hidden_size'],  # 128
            num_layers=1,
            bidirectional=True,
            batch_first=True)

        self.output_size = 2 * opt[
            'encoder_lstm_hidden_size'] + self.elmo_size  # 128*2 +1024 = 1280

        # hand-crafted features
        pos_size = self.create_pos_embed(opt)  # self.pos_embedding, 18
        ner_size = self.create_ner_embed(opt)  # self.ner_embedding, 18
        feat_size = 4
        self.manual_fea_size = pos_size + ner_size + feat_size  # 40

    def forward(self, batch):
        doc_tok = Variable(batch['doc_tok'])
        doc_ctok = Variable(batch['doc_ctok'])
        doc_pos = Variable(batch['doc_pos'])
        doc_ner = Variable(batch['doc_ner'])
        doc_fea = Variable(batch['doc_fea'])
        query_tok = Variable(batch['query_tok'])
        query_ctok = Variable(batch['query_ctok'])

        emb = self.embedding if self.training else self.eval_embed
        doc_emb = emb(doc_tok)
        query_emb = emb(query_tok)

        doc_elmo = self.elmo(doc_ctok)['elmo_representations'][0]
        query_elmo = self.elmo(query_ctok)['elmo_representations'][0]

        doc_o, _ = self.lstm(torch.cat([doc_emb, doc_elmo],
                                       2))  # [batch, seq_len, 2 * encoder_lstm_hidden_size]
        doc_o = nn.Dropout(self.dropout_p)(doc_o)
        U_P = torch.cat([doc_o, doc_elmo], 2)  # [batch, seq_len, 2 * encoder_lstm_hidden_size + elmo_size]

        query_o, _ = self.lstm(torch.cat([query_emb, query_elmo], 2))
        query_o = nn.Dropout(self.dropout_p)(query_o)
        U_Q = torch.cat([query_o, query_elmo], 2)

        doc_pos_emb = self.pos_embedding(doc_pos)
        doc_ner_emb = self.ner_embedding(doc_ner)
        doc_manual_feature = torch.cat([doc_pos_emb, doc_ner_emb, doc_fea], -1)

        return U_Q, U_P, doc_manual_feature
Example #14
class ElmoTokenEmbedderCached(TokenEmbedder):
    """
    Compute a single layer of ELMo representations.

    This class serves as a convenience when you only want to use one layer of
    ELMo representations at the input of your network.  It's essentially a wrapper
    around Elmo(num_output_representations=1, ...)

    Parameters
    ----------
    options_file : ``str``, required.
        An ELMo JSON options file.
    weight_file : ``str``, required.
        An ELMo hdf5 weight file.
    do_layer_norm : ``bool``, optional.
        Should we apply layer normalization (passed to ``ScalarMix``)?
    dropout : ``float``, optional, (default = 0.5).
        The dropout value to be applied to the ELMo representations.
    requires_grad : ``bool``, optional
        If True, compute gradient of ELMo parameters for fine tuning.
    projection_dim : ``int``, optional
        If given, we will project the ELMo embedding down to this dimension.  We recommend that you
        try using ELMo with a lot of dropout and no projection first, but we have found a few cases
        where projection helps (particularly where there is very limited training data).
    vocab_to_cache : ``List[str]``, optional.
        A list of words to pre-compute and cache character convolutions
        for. If you use this option, the ElmoTokenEmbedder expects that you pass word
        indices of shape (batch_size, timesteps) to forward, instead
        of character indices. If you use this option and pass a word which
        wasn't pre-cached, this will break.
    scalar_mix_parameters : ``List[float]``, optional, (default=None)
        If not ``None``, use these scalar mix parameters to weight the representations
        produced by different layers. These mixing weights are not updated during
        training.
    """
    def __init__(self,
                 options_file: str,
                 weight_file: str,
                 do_layer_norm: bool = False,
                 dropout: float = 0.5,
                 requires_grad: bool = False,
                 projection_dim: int = None,
                 vocab_to_cache: List[str] = None,
                 scalar_mix_parameters: List[float] = None) -> None:
        super(ElmoTokenEmbedderCached, self).__init__()
        self.cache = {}

        self._elmo = Elmo(options_file,
                          weight_file,
                          1,
                          do_layer_norm=do_layer_norm,
                          dropout=dropout,
                          requires_grad=requires_grad,
                          vocab_to_cache=vocab_to_cache,
                          scalar_mix_parameters=scalar_mix_parameters)
        if projection_dim:
            self._projection = torch.nn.Linear(self._elmo.get_output_dim(),
                                               projection_dim)
            self.output_dim = projection_dim
        else:
            self._projection = None
            self.output_dim = self._elmo.get_output_dim()

    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            word_inputs: torch.Tensor = None,
            words=None) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.
        word_inputs : ``torch.Tensor``, optional.
            If you passed a cached vocab, you can in addition pass a tensor of shape
            ``(batch_size, timesteps)``, which represent word ids which have been pre-cached.

        Returns
        -------
        The ELMo representations for the input sequence, shape
        ``(batch_size, timesteps, embedding_dim)``
        """
        keys, lens = generate_keys(words, inputs)

        all_keys_in_cache = all(key in self.cache for key in keys)

        if all_keys_in_cache:
            elmo_representations = self.retrieve_from_cache(keys)
            print("cache hit!", inputs.size(), elmo_representations.size())

        else:
            elmo_output = self._elmo(inputs, word_inputs)
            elmo_representations = elmo_output['elmo_representations'][0]
            self.add_to_cache(elmo_representations, keys, lens)

        if self._projection:
            projection = self._projection
            for _ in range(elmo_representations.dim() - 2):
                projection = TimeDistributed(projection)
            elmo_representations = projection(elmo_representations)
        return elmo_representations

    # Custom vocab_to_cache logic requires a from_params implementation.
    @classmethod
    def from_params(
            cls, vocab: Vocabulary,
            params: Params) -> 'ElmoTokenEmbedderCached':  # type: ignore
        # pylint: disable=arguments-differ
        params.add_file_to_archive('options_file')
        params.add_file_to_archive('weight_file')
        options_file = params.pop('options_file')
        weight_file = params.pop('weight_file')
        requires_grad = params.pop('requires_grad', False)
        do_layer_norm = params.pop_bool('do_layer_norm', False)
        dropout = params.pop_float("dropout", 0.5)
        namespace_to_cache = params.pop("namespace_to_cache", None)
        if namespace_to_cache is not None:
            vocab_to_cache = list(
                vocab.get_token_to_index_vocabulary(namespace_to_cache).keys())
        else:
            vocab_to_cache = None
        projection_dim = params.pop_int("projection_dim", None)
        scalar_mix_parameters = params.pop('scalar_mix_parameters', None)
        params.assert_empty(cls.__name__)
        return cls(options_file=options_file,
                   weight_file=weight_file,
                   do_layer_norm=do_layer_norm,
                   dropout=dropout,
                   requires_grad=requires_grad,
                   projection_dim=projection_dim,
                   vocab_to_cache=vocab_to_cache,
                   scalar_mix_parameters=scalar_mix_parameters)

    def add_to_cache(self, embeddings, keys, lens):
        for key, length, embedding in zip(keys, lens, embeddings):
            if key not in self.cache:
                self.cache[key] = embedding[:length, :].to(torch.device("cpu"))
            #print("added", embedding.size(),self.cache[key].size(), len(key))

    #print(len(self.cache))

    def retrieve_from_cache(self, keys):
        embeddings = []
        for key in keys:
            embedding = self.cache[key].to(
                torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            embeddings.append(embedding)
            #print("retrieved",embedding.size(),len(key))
        return pad_sequence(embeddings, batch_first=True)
Example #15
class ElmoTokenEmbedder(TokenEmbedder):
    """
    Compute a single layer of ELMo representations.

    This class serves as a convenience when you only want to use one layer of
    ELMo representations at the input of your network.  It's essentially a wrapper
    around Elmo(num_output_representations=1, ...)

    # Parameters

    options_file : `str`, required.
        An ELMo JSON options file.
    weight_file : `str`, required.
        An ELMo hdf5 weight file.
    do_layer_norm : `bool`, optional.
        Should we apply layer normalization (passed to `ScalarMix`)?
    dropout : `float`, optional, (default = 0.5).
        The dropout value to be applied to the ELMo representations.
    requires_grad : `bool`, optional
        If True, compute gradient of ELMo parameters for fine tuning.
    projection_dim : `int`, optional
        If given, we will project the ELMo embedding down to this dimension.  We recommend that you
        try using ELMo with a lot of dropout and no projection first, but we have found a few cases
        where projection helps (particularly where there is very limited training data).
    vocab_to_cache : `List[str]`, optional.
        A list of words to pre-compute and cache character convolutions
        for. If you use this option, the ElmoTokenEmbedder expects that you pass word
        indices of shape (batch_size, timesteps) to forward, instead
        of character indices. If you use this option and pass a word which
        wasn't pre-cached, this will break.
    scalar_mix_parameters : `List[float]`, optional, (default = None)
        If not `None`, use these scalar mix parameters to weight the representations
        produced by different layers. These mixing weights are not updated during
        training. The mixing weights here should be the unnormalized (i.e., pre-softmax)
        weights. So, if you wanted to use only the 1st layer of a 2-layer ELMo,
        you can set this to [-9e10, 1, -9e10].
    """
    def __init__(
        self,
        options_file: str,
        weight_file: str,
        do_layer_norm: bool = False,
        dropout: float = 0.5,
        requires_grad: bool = False,
        projection_dim: int = None,
        vocab_to_cache: List[str] = None,
        scalar_mix_parameters: List[float] = None,
    ) -> None:
        super().__init__()

        self._elmo = Elmo(
            options_file,
            weight_file,
            1,
            do_layer_norm=do_layer_norm,
            dropout=dropout,
            requires_grad=requires_grad,
            vocab_to_cache=vocab_to_cache,
            scalar_mix_parameters=scalar_mix_parameters,
        )
        if projection_dim:
            self._projection = torch.nn.Linear(self._elmo.get_output_dim(),
                                               projection_dim)
            self.output_dim = projection_dim
        else:
            self._projection = None
            self.output_dim = self._elmo.get_output_dim()

    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self,
                tokens: torch.Tensor,
                word_inputs: torch.Tensor = None) -> torch.Tensor:
        """
        # Parameters

        tokens : `torch.Tensor`
            Shape `(batch_size, timesteps, 50)` of character ids representing the current batch.
        word_inputs : `torch.Tensor`, optional.
            If you passed a cached vocab, you can in addition pass a tensor of shape
            `(batch_size, timesteps)`, which represent word ids which have been pre-cached.

        # Returns

        The ELMo representations for the input sequence, shape
        `(batch_size, timesteps, embedding_dim)`
        """
        elmo_output = self._elmo(tokens, word_inputs)
        elmo_representations = elmo_output["elmo_representations"][0]
        if self._projection:
            projection = self._projection
            for _ in range(elmo_representations.dim() - 2):
                projection = TimeDistributed(projection)
            elmo_representations = projection(elmo_representations)
        return elmo_representations

    # Custom vocab_to_cache logic requires a from_params implementation.
    @classmethod
    def from_params(  # type: ignore
            cls, vocab: Vocabulary, params: Params,
            **extras) -> "ElmoTokenEmbedder":

        options_file = params.pop("options_file")
        weight_file = params.pop("weight_file")
        requires_grad = params.pop("requires_grad", False)
        do_layer_norm = params.pop_bool("do_layer_norm", False)
        dropout = params.pop_float("dropout", 0.5)
        namespace_to_cache = params.pop("namespace_to_cache", None)
        if namespace_to_cache is not None:
            vocab_to_cache = list(
                vocab.get_token_to_index_vocabulary(namespace_to_cache).keys())
        else:
            vocab_to_cache = None
        projection_dim = params.pop_int("projection_dim", None)
        scalar_mix_parameters = params.pop("scalar_mix_parameters", None)
        params.assert_empty(cls.__name__)
        return cls(
            options_file=options_file,
            weight_file=weight_file,
            do_layer_norm=do_layer_norm,
            dropout=dropout,
            requires_grad=requires_grad,
            projection_dim=projection_dim,
            vocab_to_cache=vocab_to_cache,
            scalar_mix_parameters=scalar_mix_parameters,
        )
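
A hedged usage sketch for the ElmoTokenEmbedder defined above, assuming AllenNLP is installed and the model files can be fetched; the small-ELMo URLs are the same ones that appear in Example #18, and any compatible local options/weights pair would work as well.

from allennlp.modules.elmo import batch_to_ids

options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"

embedder = ElmoTokenEmbedder(options_file, weight_file, dropout=0.0)

# batch_to_ids turns tokenized sentences into character ids of shape (batch, timesteps, 50).
character_ids = batch_to_ids([["The", "cat", "sat", "."], ["Hello", "world"]])
representations = embedder(character_ids)  # (batch_size, timesteps, embedding_dim)
print(representations.shape, embedder.get_output_dim())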
Example #16
class ModelGraph(nn.Module):
    def __init__(self, general_config, glove_embedding):
        super(ModelGraph, self).__init__()
        self.general_config = general_config
        self.dropout = nn.Dropout(self.general_config.dropout)

        if self.general_config.embedding_model.find("elmo") >= 0:
            elmo_prefix = self.general_config.__dict__[
                self.general_config.embedding_model]
            options_file = elmo_prefix + '_options.json'
            weights_file = elmo_prefix + '_weights.hdf5'
            self.elmo_L = self.general_config.elmo_layer
            self.elmo = Elmo(options_file,
                             weights_file,
                             self.elmo_L,
                             dropout=self.general_config.dropout)
            if self.elmo_L > 1:
                self.elmo_layer_interp = nn.Linear(self.elmo_L, 1, bias=False)
            emb_size = self.elmo.get_output_dim()
        else:
            assert glove_embedding is not None
            self.glove_embedding = nn.Embedding(glove_embedding.shape[0],
                                                glove_embedding.shape[1])
            self.glove_embedding.weight.data.copy_(
                torch.from_numpy(glove_embedding))
            emb_size = glove_embedding.shape[1]

        self.passage_encoder = nn.LSTM(emb_size,
                                       self.general_config.lstm_size,
                                       num_layers=1,
                                       batch_first=True,
                                       bidirectional=True)

        self.question_encoder = nn.LSTM(emb_size,
                                        self.general_config.lstm_size,
                                        num_layers=1,
                                        batch_first=True,
                                        bidirectional=True)

        emb_size = 2 * self.general_config.lstm_size

        if self.general_config.use_mention_feature:
            self.mention_width_embedding = nn.Embedding(
                self.general_config.mention_max_width,
                self.general_config.feature_size)
            self.mention_type_embedding = nn.Embedding(
                4, self.general_config.feature_size)

        mention_emb_size = 2 * emb_size
        if self.general_config.use_mention_head:
            mention_emb_size += emb_size
        if self.general_config.use_mention_feature:
            mention_emb_size += 2 * self.general_config.feature_size

        if self.general_config.use_mention_head:
            self.mention_head_scorer = nn.Linear(emb_size, 1)

        if self.general_config.mention_compress_size > 0:
            self.mention_compressor = nn.Linear(
                mention_emb_size, self.general_config.mention_compress_size)
            mention_emb_size = self.general_config.mention_compress_size

        if self.general_config.graph_encoding == "GRN":
            print("With GRN as the graph encoder")
            self.graph_encoder = grn_encoder_utils.GRNEncoder(
                mention_emb_size, self.general_config.dropout)
        elif self.general_config.graph_encoding == "GCN":
            print("With GCN as the graph encoder")
            self.graph_encoder = gcn_encoder_utils.GCNEncoder(
                mention_emb_size, self.general_config.dropout)
        else:
            self.graph_encoder = None

        if self.general_config.matching_op == "concat":
            self.concat_projector = nn.Linear(mention_emb_size * 2, 1)

        if self.general_config.graph_encoding in ("GRN", "GCN"):
            self.matching_integrater = nn.Linear(
                general_config.graph_encoding_steps + 1, 1, bias=False)

    def get_repre(self, ids):
        if self.general_config.embedding_model.find('elmo') >= 0:
            elmo_outputs = self.elmo(ids)
            if self.elmo_L > 1:
                repre = torch.stack(elmo_outputs['elmo_representations'],
                                    dim=3)  # [batch, seq, emb, L]
                repre = self.elmo_layer_interp(repre).squeeze(
                    dim=3)  # [batch, seq, emb]
            else:
                repre = elmo_outputs['elmo_representations'][
                    0]  # [batch, seq, emb]
            return repre * elmo_outputs['mask'].float().unsqueeze(
                dim=2)  # [batch, seq, emb]
        else:
            return self.glove_embedding(ids)

    def forward(self, batch):
        if self.general_config.embedding_model.find('elmo') >= 0:
            batch_size, passage_max_len, other = list(
                batch['passage_ids'].size())
        else:
            batch_size, passage_max_len = list(batch['passage_ids'].size())
        assert passage_max_len % 10 == 0

        if self.general_config.embedding_model.find('elmo') >= 0:
            passage_ids = batch['passage_ids'].view(
                batch_size * 10, passage_max_len // 10,
                other)  # [batch*10, passage/10, other]
        else:
            passage_ids = batch['passage_ids'].view(
                batch_size * 10,
                passage_max_len // 10)  # [batch*10, passage/10]

        passage_repre = self.get_repre(
            passage_ids)  # [batch*10, passage/10, elmo_emb]
        passage_repre, _ = self.passage_encoder(
            passage_repre)  # [batch*10, passage/10, lstm_emb]
        emb_size = utils.shape(passage_repre, 2)
        passage_repre = passage_repre.contiguous().view(
            batch_size, passage_max_len, emb_size)

        question_repre = self.get_repre(
            batch['question_ids'])  # [batch, question, elmo_emb]
        question_repre, _ = self.question_encoder(
            question_repre)  # [batch, question, lstm_emb]

        # modeling question
        batch_size = len(batch['ids'])
        question_starts = torch.zeros(batch_size, 1,
                                      dtype=torch.long).cuda()  # [batch, 1]
        question_ends = batch['question_lens'].view(batch_size,
                                                    1) - 1  # [batch, 1]
        question_types = torch.zeros(batch_size, 1,
                                     dtype=torch.long).cuda()  # [batch, 1]
        question_mask_float = torch.ones(
            batch_size, 1, dtype=torch.float).cuda()  # [batch, 1]
        question_emb = self.get_mention_embedding(
            question_repre, question_starts, question_ends, question_types,
            question_mask_float).squeeze(dim=1)  # [batch, emb]

        # modeling mentions
        mention_starts = batch['mention_starts']
        mention_ends = batch['mention_ends']
        mention_types = batch['mention_types']
        mention_nums = batch['mention_nums']

        mention_max_num = utils.shape(mention_starts, 1)
        mention_mask = utils.sequence_mask(mention_nums, mention_max_num)
        mention_emb = self.get_mention_embedding(passage_repre, mention_starts,
                                                 mention_ends, mention_types,
                                                 mention_mask.float())

        if self.general_config.mention_compress_size > 0:
            question_emb = self.mention_compressor(question_emb)
            mention_emb = self.mention_compressor(mention_emb)

        matching_results = []
        rst_seq = self.perform_matching(mention_emb, question_emb)
        matching_results.append(rst_seq)

        # graph encoding
        if self.general_config.graph_encoding in ('GCN', 'GRN'):
            edges = batch['edges']  # [batch, mention, edge]
            edge_nums = batch['edge_nums']  # [batch, mention]
            edge_max_num = utils.shape(edges, 2)
            edge_mask = utils.sequence_mask(
                edge_nums.view(batch_size * mention_max_num),
                edge_max_num).view(batch_size, mention_max_num,
                                   edge_max_num)  # [batch, mention, edge]
            assert not (edge_mask &
                        (~mention_mask.unsqueeze(dim=2))).any().item()

            for i in range(self.general_config.graph_encoding_steps):
                mention_emb_new = self.graph_encoder(mention_emb,
                                                     mention_mask.float(),
                                                     edges, edge_mask.float())
                mention_emb = mention_emb_new + mention_emb if self.general_config.graph_residual else mention_emb_new
                rst_graph = self.perform_matching(mention_emb, question_emb)
                matching_results.append(rst_graph)

        if len(matching_results) > 1:
            assert len(matching_results
                       ) == self.general_config.graph_encoding_steps + 1
            matching_results = torch.stack(
                matching_results, dim=2)  # [batch, mention, graph_step+1]
            logits = self.matching_integrater(matching_results).squeeze(
                dim=2)  # [batch, mention]
        else:
            assert len(matching_results) == 1
            logits = matching_results[0]  # [batch, mention]

        candidates, candidate_num, candidate_appear_num = \
                batch['candidates'], batch['candidate_num'], batch['candidate_appear_num']
        _, cand_max_num, cand_pos_max_num = list(candidates.size())

        candidate_mask = utils.sequence_mask(candidate_num,
                                             cand_max_num)  # [batch, cand]
        candidate_appear_mask = utils.sequence_mask(
            candidate_appear_num.view(batch_size * cand_max_num),
            cand_pos_max_num).view(batch_size, cand_max_num,
                                   cand_pos_max_num)  # [batch, cand, pos]
        assert not (candidate_appear_mask &
                    (~candidate_mask.unsqueeze(dim=2))).any().item()

        # ideas to get 'candidate_appear_dist'

        ## idea 1
        #candidate_appear_logits = (utils.batch_gather(logits, candidates) + \
        #        candidate_appear_mask.float().log()).view(batch_size, cand_max_num * cand_pos_max_num) # [batch, cand * pos]
        #candidate_appear_logits = torch.clamp(candidate_appear_logits, -1e1, 1e1) # [batch, cand * pos]
        #candidate_appear_dist = F.softmax(candidate_appear_logits, dim=1).view(batch_size,
        #        cand_max_num, cand_pos_max_num) # [batch, cand, pos]

        ## idea 2
        #candidate_appear_dist = torch.clamp(utils.batch_gather(logits, candidates).exp() * \
        #        candidate_appear_mask.float(), 1e-6, 1e6).view(batch_size, cand_max_num * cand_pos_max_num) # [batch, cand * pos]
        #candidate_appear_dist = candidate_appear_dist / candidate_appear_dist.sum(dim=1, keepdim=True)
        #candidate_appear_dist = candidate_appear_dist.view(batch_size, cand_max_num, cand_pos_max_num)

        ## idea 3
        #candidate_appear_dist = F.softmax(utils.batch_gather(logits, candidates).view(batch_size,
        #        cand_max_num * cand_pos_max_num), dim=1) # [batch, cand * pos]
        #candidate_appear_dist = torch.clamp(candidate_appear_dist * candidate_appear_mask.view(batch_size,
        #        cand_max_num * cand_pos_max_num).float(), 1e-8, 1.0) # [batch, cand * pos]
        #candidate_appear_dist = (candidate_appear_dist / candidate_appear_dist.sum(dim=1, keepdim=True)).view(batch_size,
        #        cand_max_num, cand_pos_max_num) # [batch, cand, pos]

        ## get 'candidate_dist', which is common for idea 1, 2 and 3
        #if not (candidate_appear_dist > 0).all().item():
        #    print(candidate_appear_dist)
        #    assert False
        #candidate_dist = candidate_appear_dist.sum(dim=2) # [batch, cand]

        # original impl
        mention_dist = F.softmax(logits, dim=1)
        if utils.contain_nan(mention_dist):
            print(logits)
            print(mention_dist)
            assert False
        candidate_appear_dist = utils.batch_gather(
            mention_dist, candidates) * candidate_appear_mask.float()
        candidate_dist = candidate_appear_dist.sum(
            dim=2) * candidate_mask.float()
        candidate_dist = utils.clip_and_normalize(candidate_dist, 1e-6)
        assert utils.contain_nan(candidate_dist) == False
        # end of original impl

        candidate_logits = candidate_dist.log()  # [batch, cand]
        predictions = candidate_logits.argmax(dim=1)  # [batch]
        if not (predictions < candidate_num).all().item():
            print(candidate_dist)
            print(candidate_num)
            assert False

        if 'refs' not in batch or batch['refs'] is None:
            return {'predictions': predictions}

        refs = batch['refs']
        loss = nn.CrossEntropyLoss()(candidate_logits, refs)
        right_count = (predictions == refs).sum()
        return {
            'predictions': predictions,
            'loss': loss,
            'right_count': right_count
        }

    # input_repre: [batch, seq, emb]
    # mention_starts, mention_ends and mention_mask: [batch, mentions]
    # s_m(i) = FFNN(g_i)
    # g_i = [x_i^start, x_i^end, x_i^head, \phi(i)]
    def get_mention_embedding(self, input_repre, mention_starts, mention_ends,
                              mention_types, mention_mask_float):
        mention_emb_list = []
        mention_start_emb = utils.batch_gather(
            input_repre, mention_starts)  # [batch, mentions, emb]
        mention_emb_list.append(mention_start_emb)
        mention_end_emb = utils.batch_gather(
            input_repre, mention_ends)  # [batch, mentions, emb]
        mention_emb_list.append(mention_end_emb)

        if self.general_config.use_mention_head:
            batch_size, mention_num = list(mention_starts.size())

            span_starts = mention_starts.unsqueeze(
                dim=2)  # [batch, mentions, 1]
            span_ends = mention_ends.unsqueeze(dim=2)  # [batch, mentions, 1]
            span_range = torch.arange(
                self.general_config.mention_max_width
            ).view(1, 1,
                   self.general_config.mention_max_width)  # [1, 1, span_width]
            if torch.cuda.is_available():
                span_range = span_range.cuda()
            span_indices_raw = span_starts + span_range  # [batch, mentions, span_width]
            span_indices = torch.min(
                span_indices_raw, span_ends)  # [batch, mentions, span_width]
            span_mask = span_indices_raw <= span_ends  # [batch, mention, span_width]

            span_emb = utils.batch_gather(
                input_repre,
                span_indices)  # [batch, mentions, span_width, emb]
            span_scores = self.mention_head_scorer(span_emb).squeeze(dim=-1) + \
                    torch.log(span_mask.float())  # [batch, mentions, span_width]
            span_attn = F.softmax(span_scores,
                                  dim=-1)  # [batch, mentions, span_width]
            mention_head_emb = span_emb * span_attn.unsqueeze(
                dim=-1)  # [batch, mentions, span_width, emb]
            mention_head_emb = torch.sum(mention_head_emb,
                                         dim=2)  # [batch, mentions, emb]
            mention_emb_list.append(mention_head_emb)

        if self.general_config.use_mention_feature:
            mention_width = 1 + mention_ends - mention_starts  # [batch, mentions]
            mention_width_index = torch.clamp(
                mention_width, 1,
                self.general_config.mention_max_width) - 1  # [batch, mentions]
            mention_width_emb = self.mention_width_embedding(
                mention_width_index)  # [batch, mentions, emb]
            mention_width_emb = self.dropout(mention_width_emb)
            mention_emb_list.append(mention_width_emb)
            mention_type_emb = self.mention_type_embedding(mention_types)
            mention_emb_list.append(mention_type_emb)

        return torch.cat(mention_emb_list,
                         dim=2) * mention_mask_float.unsqueeze(dim=2)

    # mention_emb: [batch, mention, emb]
    # question_emb: [batch, emb]
    def perform_matching(self, mention_emb, question_emb):
        if self.general_config.matching_op == "matmul":
            # [batch, mention, emb] * [batch, emb, 1] ==> [batch, mention, 1]
            logits = mention_emb.matmul(
                question_emb.unsqueeze(dim=2)).squeeze(dim=2)
            return logits
        elif self.general_config.matching_op == "concat":
            mention_max_num = utils.shape(mention_emb, 1)
            question_emb = question_emb.unsqueeze(dim=1).expand(
                -1, mention_max_num, -1)
            combined_emb = torch.cat([mention_emb, question_emb],
                                     dim=2)  # [batch, mention, emb]
            logits = self.concat_projector(combined_emb).squeeze(
                dim=2)  # [batch, mention]
            return logits
        else:
            assert False, "Unsupported matching_op: {}".format(
                self.general_config.matching_op)
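
A minimal standalone sketch of the layer-interpolation trick used in get_repre above: when elmo_L > 1, the L output representations are stacked along a trailing axis and mixed by a bias-free nn.Linear(L, 1). The random tensors stand in for elmo_outputs['elmo_representations']; all shapes are illustrative.

import torch
import torch.nn as nn

L = 2                                               # number of ELMo layers requested
elmo_layer_interp = nn.Linear(L, 1, bias=False)

# Stand-ins for elmo_outputs['elmo_representations']: L tensors of shape [batch, seq, emb].
elmo_representations = [torch.randn(4, 7, 256) for _ in range(L)]

stacked = torch.stack(elmo_representations, dim=3)  # [batch, seq, emb, L]
mixed = elmo_layer_interp(stacked).squeeze(dim=3)   # [batch, seq, emb]
assert mixed.shape == (4, 7, 256)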
Example #17
class EBMNLPTagger(pl.LightningModule):
    def __init__(self, hparams):
        """
        input:
            hparams: namespace with the following items:
                'data_dir' (str): Data Directory. default: './official/ebm_nlp_1_00'
                'bioelmo_dir' (str): BioELMo Directory. default: './models/bioelmo'
                'max_length' (int): Max Length. default: 1024
                'lr' (float): Learning Rate. default: 1e-2
                'fine_tune_bioelmo' (bool): Whether to Fine Tune BioELMo. default: False
                'lr_bioelmo' (float): Learning Rate in BioELMo Fine-tuning. default: 1e-4
        """
        super().__init__()
        self.hparams = hparams
        self.itol = ID_TO_LABEL
        self.ltoi = {v: k for k, v in self.itol.items()}

        # Load Pretrained BioELMo
        DIR_ELMo = Path(str(self.hparams.bioelmo_dir))
        self.bioelmo = Elmo(DIR_ELMo / 'biomed_elmo_options.json',
                            DIR_ELMo / 'biomed_elmo_weights.hdf5',
                            1,
                            requires_grad=bool(self.hparams.fine_tune_bioelmo),
                            dropout=0)
        self.bioelmo_output_dim = self.bioelmo.get_output_dim()

        # ELMo padding token (in ELMo, the token with ID 0 is used for padding).
        # The padding token is stored on the first line of the vocabulary file.
        VOCAB_FILE_PATH = DIR_ELMo / 'vocab.txt'
        with open(VOCAB_FILE_PATH, encoding='utf-8') as f:
            self.bioelmo_pad_token = f.readline().strip()

        # Initialize Intermediate Affine Layer
        self.hidden_to_tag = nn.Linear(int(self.bioelmo_output_dim),
                                       len(self.itol))

        # Initialize CRF
        TRANSITIONS = conditional_random_field.allowed_transitions(
            constraint_type='BIO', labels=self.itol)
        self.crf = conditional_random_field.ConditionalRandomField(
            # set to 7 because here "tags" means ['O', 'B-P', 'I-P', 'B-I', 'I-I', 'B-O', 'I-O']
            # no need to include 'BOS' and 'EOS' in "tags"
            num_tags=len(self.itol),
            constraints=TRANSITIONS,
            include_start_end_transitions=False)
        self.crf.reset_parameters()

    def get_device(self):
        return self.crf.state_dict()['transitions'].device

    def _forward_crf(self, hidden, gold_tags_padded, crf_mask):
        """
        input:
            hidden (torch.tensor) (n_batch, seq_length, hidden_dim)
            gold_tags_padded (torch.tensor) (n_batch, seq_length)
            crf_mask (torch.bool) (n_batch, seq_length)
        output:
            result (dict)
                'log_likelihood' : torch.tensor
                'pred_tags_packed' : torch.nn.utils.rnn.PackedSequence
                'gold_tags_padded' : torch.tensor
        """
        result = {}

        # top k=1 tagging (shared by both branches)
        Y = [
            torch.tensor(best_path)
            for best_path, _ in self.crf.viterbi_tags(logits=hidden,
                                                      mask=crf_mask)
        ]
        result['pred_tags_packed'] = rnn.pack_sequence(Y, enforce_sorted=False)

        if gold_tags_padded is not None:
            # Log likelihood of the gold tag sequence under the CRF
            result['log_likelihood'] = self.crf.forward(hidden, gold_tags_padded,
                                                        crf_mask)
            result['gold_tags_padded'] = gold_tags_padded

        return result

    def forward(self, tokens, gold_tags=None):
        """
        input:
            tokens (list(list(str))): tokenized sentences of one mini-batch
            gold_tags (list(list(int)), optional): gold tag ids for each sentence
        output:
            result (dict)
                'log_likelihood' : torch.tensor (only when gold_tags is given)
                'pred_tags_packed' : torch.nn.utils.rnn.PackedSequence
                'gold_tags_padded' : torch.tensor (only when gold_tags is given)
        """
        # character_ids: torch.tensor (n_batch, max_len, 50)
        character_ids = batch_to_ids(tokens)
        character_ids = character_ids[:, :self.hparams.max_length, :]
        character_ids = character_ids.to(self.get_device())

        # character_ids -> BioELMo hidden state of the last layer & mask
        out = self.bioelmo(character_ids)

        # Turn on gradient tracking
        # Affine transformation (Hidden_dim -> N_tag)
        hidden = out['elmo_representations'][-1]
        hidden.requires_grad_()
        hidden = self.hidden_to_tag(hidden)

        crf_mask = out['mask'].to(torch.bool).to(self.get_device())

        if gold_tags is not None:
            gold_tags = [torch.tensor(seq) for seq in gold_tags]
            gold_tags_padded = rnn.pad_sequence(gold_tags,
                                                batch_first=True,
                                                padding_value=self.ltoi['O'])
            gold_tags_padded = gold_tags_padded[:, :self.hparams.max_length]
            gold_tags_padded = gold_tags_padded.to(self.get_device())
        else:
            gold_tags_padded = None

        result = self._forward_crf(hidden, gold_tags_padded, crf_mask)
        return result

    def step(self, batch, batch_nb, *optimizer_idx):
        tokens_nopad = batch['tokens']
        tags_nopad = batch['tags']

        # Negative Log Likelihood
        result = self.forward(tokens_nopad, tags_nopad)
        returns = {
            'loss': result['log_likelihood'] * (-1.0),
            'T': result['gold_tags_padded'],
            'Y': result['pred_tags_packed'],
            'I': batch['pmid']
        }
        return returns

    def unpack_pred_tags(self, Y_packed):
        """
        input:
            Y_packed: torch.nn.utils.rnn.PackedSequence
        output:
            Y: list(list(str))
                Predicted NER tagging sequence.
        """
        Y_padded, Y_len = rnn.pad_packed_sequence(Y_packed,
                                                  batch_first=True,
                                                  padding_value=-1)
        Y_padded = Y_padded.numpy().tolist()
        Y_len = Y_len.numpy().tolist()

        # Replace B- tag with I- tag because the original paper defines the NER task as identification of spans, not entities
        Y = [[self.itol[ix].replace('B-', 'I-') for ix in ids[:length]]
             for ids, length in zip(Y_padded, Y_len)]

        return Y

    def unpack_gold_and_pred_tags(self, T_padded, Y_packed):
        """
        input:
            T_padded: torch.tensor
            Y_packed: torch.nn.utils.rnn.PackedSequence
        output:
            T: list(list(str))
                Gold NER tagging sequence.
            Y: list(list(str))
                Predicted NER tagging sequence.
        """
        Y = self.unpack_pred_tags(Y_packed)
        Y_len = [len(seq) for seq in Y]

        T_padded = T_padded.numpy().tolist()

        # Replace B- tag with I- tag because the original paper defines the NER task as identification of spans, not entities
        T = [[self.itol[ix].replace('B-', 'I-') for ix in ids[:length]]
             for ids, length in zip(T_padded, Y_len)]

        return T, Y

    def gather_outputs(self, outputs):
        if len(outputs) > 1:
            loss = torch.mean(
                torch.tensor([output['loss'] for output in outputs]))
        else:
            loss = torch.mean(outputs[0]['loss'])

        I = []
        Y = []
        T = []

        for output in outputs:
            T_batch, Y_batch = self.unpack_gold_and_pred_tags(
                output['T'].cpu(), output['Y'].cpu())
            T += T_batch
            Y += Y_batch
            I += output['I'].cpu().numpy().tolist()

        returns = {'loss': loss, 'T': T, 'Y': Y, 'I': I}

        return returns

    def training_step(self, batch, batch_nb, *optimizer_idx):
        # Process on individual mini-batches
        """
        (batch) -> (dict or OrderedDict)
        # Caution: key for loss function must exactly be 'loss'.
        """
        return self.step(batch, batch_nb, *optimizer_idx)

    def training_step_end(self, outputs):
        """
        outputs(dict) -> loss(dict or OrderedDict)
        # Caution: key must exactly be 'loss'.
        """
        loss = torch.mean(outputs['loss'])

        progress_bar = {'train_loss': loss}
        returns = {
            'loss': loss,
            'T': outputs['T'],
            'Y': outputs['Y'],
            'I': outputs['I'],
            'progress_bar': progress_bar
        }
        return returns

    def training_epoch_end(self, outputs):
        """
        outputs(list of dict) -> loss(dict or OrderedDict)
        # Caution: key must exactly be 'loss'.
        """
        outs = self.gather_outputs(outputs)
        loss = outs['loss']
        I = outs['I']
        Y = outs['Y']
        T = outs['T']

        get_logger(self.hparams.version).info(
            f'========== Training Epoch {self.current_epoch} ==========')
        get_logger(self.hparams.version).info(f'Loss: {loss.item()}')
        get_logger(self.hparams.version).info(
            f'Entity-wise classification report\n{seq_classification_report(T, Y, 4)}'
        )
        get_logger(self.hparams.version).info(
            f'Token-wise classification report\n{span_classification_report(T, Y, 4)}'
        )

        progress_bar = {'train_loss': loss}
        returns = {'loss': loss, 'progress_bar': progress_bar}
        return returns

    def validation_step(self, batch, batch_nb):
        # Process on individual mini-batches
        """
        (batch) -> (dict or OrderedDict)
        """
        return self.step(batch, batch_nb)

    def validation_end(self, outputs):
        """
        For single dataloader:
            outputs(list of dict) -> (dict or OrderedDict)
        For multiple dataloaders:
            outputs(list of (list of dict)) -> (dict or OrderedDict)
        """
        outs = self.gather_outputs(outputs)
        loss = outs['loss']
        I = outs['I']
        Y = outs['Y']
        T = outs['T']

        get_logger(self.hparams.version).info(
            f'========== Validation Epoch {self.current_epoch} ==========')
        get_logger(self.hparams.version).info(f'Loss: {loss.item()}')
        get_logger(self.hparams.version).info(
            f'Entity-wise classification report\n{seq_classification_report(T, Y, 4)}'
        )
        get_logger(self.hparams.version).info(
            f'Token-wise classification report\n{span_classification_report(T, Y, 4)}'
        )

        progress_bar = {'val_loss': loss}
        returns = {'loss': loss, 'progress_bar': progress_bar}
        return returns

    def test_step(self, batch, batch_nb):
        # Process on individual mini-batches
        """
        (batch) -> (dict or OrderedDict)
        """
        return self.step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        """
        For single dataloader:
            outputs(list of dict) -> (dict or OrderedDict)
        For multiple dataloaders:
            outputs(list of (list of dict)) -> (dict or OrderedDict)
        """
        outs = self.gather_outputs(outputs)
        loss = outs['loss']
        I = outs['I']
        Y = outs['Y']
        T = outs['T']

        get_logger(self.hparams.version).info(f'========== Test ==========')
        get_logger(self.hparams.version).info(f'Loss: {loss.item()}')
        get_logger(self.hparams.version).info(
            f'Entity-wise classification report\n{seq_classification_report(T, Y, 4)}'
        )
        get_logger(self.hparams.version).info(
            f'Token-wise classification report\n{span_classification_report(T, Y, 4)}'
        )

        progress_bar = {'test_loss': loss}
        returns = {'loss': loss, 'progress_bar': progress_bar}
        return returns

    def configure_optimizers(self):
        if self.hparams.fine_tune_bioelmo:
            optimizer_bioelmo_1 = optim.Adam(self.bioelmo.parameters(),
                                             lr=float(self.hparams.lr_bioelmo))
            optimizer_bioelmo_2 = optim.Adam(self.hidden_to_tag.parameters(),
                                             lr=float(self.hparams.lr_bioelmo))
            optimizer_crf = optim.Adam(self.crf.parameters(),
                                       lr=float(self.hparams.lr))
            return [optimizer_bioelmo_1, optimizer_bioelmo_2, optimizer_crf]
        else:
            optimizer = optim.Adam(self.parameters(),
                                   lr=float(self.hparams.lr))
            return optimizer

    def train_dataloader(self):
        ds_train_val = EBMNLPDataset(
            *path_finder(self.hparams.data_dir)['train'])

        ds_train, _ = train_test_split(ds_train_val,
                                       train_size=0.8,
                                       random_state=self.hparams.random_state)
        dl_train = EBMNLPDataLoader(ds_train,
                                    batch_size=self.hparams.batch_size,
                                    shuffle=True)
        return dl_train

    def val_dataloader(self):
        ds_train_val = EBMNLPDataset(
            *path_finder(self.hparams.data_dir)['train'])

        _, ds_val = train_test_split(ds_train_val,
                                     train_size=0.8,
                                     random_state=self.hparams.random_state)
        dl_val = EBMNLPDataLoader(ds_val,
                                  batch_size=self.hparams.batch_size,
                                  shuffle=False)
        return dl_val

    def test_dataloader(self):
        ds_test = EBMNLPDataset(*path_finder(self.hparams.data_dir)['test'])
        dl_test = EBMNLPDataLoader(ds_test,
                                   batch_size=self.hparams.batch_size,
                                   shuffle=False)
        return dl_test
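
The CRF head used by EBMNLPTagger can also be exercised in isolation. The sketch below mirrors the BIO-constrained setup from __init__ but substitutes a made-up three-tag label map and random emission scores for the BioELMo hidden states; it assumes only AllenNLP's conditional_random_field module.

import torch
from allennlp.modules import conditional_random_field

itol = {0: 'O', 1: 'B-P', 2: 'I-P'}                     # toy label map for illustration
transitions = conditional_random_field.allowed_transitions('BIO', itol)
crf = conditional_random_field.ConditionalRandomField(
    num_tags=len(itol),
    constraints=transitions,
    include_start_end_transitions=False)

emissions = torch.randn(2, 6, len(itol))                # [batch, seq, num_tags]
mask = torch.ones(2, 6, dtype=torch.bool)
gold = torch.randint(0, len(itol), (2, 6))

log_likelihood = crf(emissions, gold, mask)             # negate this for the training loss
best_paths = crf.viterbi_tags(logits=emissions, mask=mask)  # list of (tag_ids, score)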
Example #18
class Seq2IdxSum(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        word_embedding_dim: int = 200,
        hidden_dim: int = 200,
        dropout_emb: float = 0.5,
        min_dec_step: int = 2,
        max_decoding_steps=3,
        fix_edu_num=-1,
        dropout: float = 0.5,
        alpha: float = 0.5,
        span_encoder_type='self_attentive',
        use_elmo: bool = True,
        attn_type: str = 'general',
        schedule_ratio_from_ground_truth: float = 0.8,
        pretrain_embedding_file=None,
        nenc_lay: int = 2,
        mult_orac_sampling: bool = False,
        word_token_indexers=None,
        compression: bool = True,
        dbg: bool = False,
        dec_avd_trigram_rep: bool = True,
        aggressive_compression: int = -1,
        compress_leadn: int = -1,
        subsentence: bool = False,
        gather='mean',
        keep_threshold: float = 0.5,
        abs_board_file: str = "/home/cc/exComp/board.txt",
        abs_dir_root: str = "/scratch/cluster/jcxu",
        serilization_name: str = "",
    ) -> None:

        super(Seq2IdxSum, self).__init__(vocab, regularizer)
        self.text_field_embedder = text_field_embedder

        elmo_weight = os.path.join(
            abs_dir_root, "elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5")
        # if not os.path.isfile(elmo_weight):
        #     import subprocess
        #     x = "wget https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5 -P {}".format(abs_dir_root)
        #     subprocess.run(x.split(" "))

        self.device = get_device()
        self.vocab = vocab
        self.dbg = dbg
        self.loss_thres = keep_threshold
        self.compression = compression
        self.comp_leadn = compress_leadn
        # Just encode the whole document without looking at compression options
        self.enc_doc = EncDoc(inp_dim=word_embedding_dim,
                              hid_dim=hidden_dim,
                              vocab=vocab,
                              dropout=dropout,
                              dropout_emb=dropout_emb,
                              pretrain_embedding_file=pretrain_embedding_file,
                              gather=gather)

        self.sent_dec = SentRNNDecoder(
            rnn_type='lstm',
            dec_hidden_size=self.enc_doc.get_output_dim(),
            dec_input_size=self.enc_doc.get_output_dim(),
            dropout=dropout,
            fixed_dec_step=fix_edu_num,
            max_dec_steps=max_decoding_steps,
            min_dec_steps=min_dec_step,
            schedule_ratio_from_ground_truth=schedule_ratio_from_ground_truth,
            dec_avd_trigram_rep=dec_avd_trigram_rep,
            mult_orac_sample_one=mult_orac_sampling,
            abs_board_file=abs_board_file,
            valid_tmp_path=abs_dir_root,
            serilization_name=serilization_name)
        if compression:
            self.compression_dec = CompressDecoder(
                context_dim=hidden_dim * 2,
                dec_state_dim=hidden_dim * 2,
                enc_hid_dim=hidden_dim,
                text_field_embedder=self.enc_doc._text_field_embedder,
                aggressive_compression=aggressive_compression,
                keep_threshold=keep_threshold,
                abs_board_file=abs_board_file,
                gather=gather,
                dropout=dropout,
                dropout_emb=dropout_emb,
                valid_tmp_path=abs_dir_root,
                serilization_name=serilization_name,
                vocab=vocab,
                elmo=use_elmo,
                elmo_weight=elmo_weight)
            self.aggressive_compression = aggressive_compression

        self.use_elmo = use_elmo
        if use_elmo:
            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
            self.elmo = Elmo(options_file, weight_file, 1, dropout=0)
            # print(self.elmo.get_output_dim())
            self._context_layer = PytorchSeq2SeqWrapper(
                torch.nn.LSTM(word_embedding_dim + self.elmo.get_output_dim(),
                              hidden_dim,
                              batch_first=True,
                              bidirectional=True))
        else:

            self._context_layer = PytorchSeq2SeqWrapper(
                torch.nn.LSTM(word_embedding_dim,
                              hidden_dim,
                              batch_first=True,
                              bidirectional=True))

        token_embedding = Embedding(
            num_embeddings=vocab.get_vocab_size('tokens'),
            embedding_dim=word_embedding_dim)
        if pretrain_embedding_file is not None:
            logger = logging.getLogger()
            logger.info(
                "Loading word embedding: {}".format(pretrain_embedding_file))
            # Embedding.from_params is a classmethod that returns a new Embedding,
            # so re-assign; otherwise the pretrained weights would be discarded.
            token_embedding = Embedding.from_params(
                vocab=vocab,
                params=Params({
                    "pretrained_file": pretrain_embedding_file,
                    "embedding_dim": word_embedding_dim
                }))
        self._text_field_embedder = BasicTextFieldEmbedder(
            {"tokens": token_embedding})

        # if span_encoder_type == 'self_attentive':
        #     self._span_encoder = SelfAttentiveSpanExtractor(
        #         self._context_layer.get_output_dim()
        #     )
        # else:
        #     raise NotImplementedError

        self._dropout = torch.nn.Dropout(p=dropout)
        self._max_decoding_steps = max_decoding_steps
        self._fix_edu_num = fix_edu_num
        if compression:
            pass
            # self.rouge_metrics_compression = self.compression_dec.rouge_metrics_compression
            # self.rouge_metrics_compression_upper_bound = self.compression_dec.rouge_metrics_compression_best_possible
        self.rouge_metrics_sent = self.sent_dec.rouge_metrics_sent
        self.mult_orac_sampling = mult_orac_sampling
        self.alpha = alpha
        initializer(self)
        if regularizer is not None:
            regularizer(self)
        self.counter = 0  # used for controlling compression and extraction

    @overrides
    def forward(self, text: Dict[str, torch.LongTensor], sent_label,
                sent_rouge, comp_rouge, comp_msk, comp_meta, comp_rouge_ratio,
                comp_seq_label, metadata):
        """

        :param text: input words. [batch, max_sent, max_word]  0 for padding bit
        :param sent_label: [batchsize, n_oracles, max_decoding_step] [1,3,6,-1,-1] sorted descendingly by rouge
        :param sent_rouge: [batchsize, n_oracles]
        :param comp_rouge: [batchsize, max_sent, max_num_compression] padding with 0
        :param comp_msk: [batchsize, max_sent, max_num_compression, max_word] deletion mask of compression.
        :param comp_meta: ScatterableList
        :param comp_rouge_ratio: [batchsize, max_sent, max_num_compression] after-compression-rouge / baseline rouge
        :param comp_seq_label: [batchsize, max_sent] index of best compression. padded with -1.
        :param metadata: ScatterableList
        :return:
        """

        batch, sent_num, max_word_num = text['tokens'].size()
        batch_, nora, _max_dec = sent_label.size()
        batchsz, ora_ = sent_rouge.size()
        batchsize, max_sent, max_compre = comp_rouge.size()
        batchsize_, _max_sent, _max_compre, _max_word_num = comp_msk.size()
        bz, max_s = comp_seq_label.size()
        assert batchsize == bz == batchsize_ == batch_ == batch == batchsz
        assert sent_num == max_sent == max_s
        assert _max_word_num == max_word_num
        text_mask = util.get_text_field_mask(text, num_wrapping_dims=1).float()
        sent_blstm_output, document_rep = self.enc_doc.forward(
            inp=text, context_msk=text_mask)
        # sent_blstm_output: [batch, sent_num, hdim*2]
        # document_rep: [batch, hdim*2]
        sent_mask = text_mask[:, :, 0]

        if self.training:
            decoder_outputs_logit, decoder_outputs_prob, [decoder_states_h, decoder_states_c] = \
                self.sent_dec.forward(context=sent_blstm_output,
                                      context_mask=sent_mask,  # batch size, enc sent num; [1,0]
                                      last_state=document_rep,  # batch size, dim;
                                      tgt=sent_label)
        else:
            decoder_outputs_logit, decoder_outputs_prob, [decoder_states_h, decoder_states_c] = \
                self.sent_dec.forward(context=sent_blstm_output,
                                      context_mask=sent_mask,  # batch size, enc sent num; [1,0]
                                      last_state=document_rep,  # batch size, dim;
                                      tgt=None)
        # Compute sent loss
        decoder_outputs_prob = flip_first_two_dim(decoder_outputs_prob)
        if not self.training:
            sent_label = sent_label[:, :, :self.sent_dec.max_dec_steps]
        sent_loss, ori_loss = self.sent_dec.comp_loss(
            decoder_outputs_prob=decoder_outputs_prob,
            oracles=sent_label,
            rouge=sent_rouge)

        # compression / sub-sentence model
        # refine_sent_selection() default is root        --subsentence
        sent_emission = self.refine_sent_selection(batchsz, self.comp_leadn,
                                                   decoder_outputs_logit,
                                                   sent_label,
                                                   decoder_outputs_prob,
                                                   metadata)
        # run compression module
        if self.compression:
            # sent_label or decoder_outputs_logit: batch, t

            sent_emission = sent_emission.detach()
            ####
            #  sent_emission: t, batch_sz. [4, 13, 0, -1, -1]...
            ####

            assert sent_emission.size()[1] == batch_

            all_attn_dist, all_scores, all_reps = self.compression_dec.forward_parallel(
                sent_decoder_states=decoder_states_h,
                sent_decoder_outputs_logit=sent_emission,
                document_rep=document_rep,
                text=text,
                text_msk=text_mask,
                span=comp_msk)
            # all_reps: t, batch_size_, max_span_num, self.concat_size

            if self.aggressive_compression > 0:
                compression_loss = self.compression_dec.comp_loss(
                    sent_emission, all_scores, comp_seq_label, comp_rouge,
                    comp_rouge_ratio)
            elif self.aggressive_compression < 0:
                # Independent Classifier
                span_score = self.compression_dec.indep_compression_judger(
                    all_reps)
                # span_prob: t, batch, max_span_num, 2
                compression_loss = self.compression_dec.comp_loss_inf_deletion(
                    sent_emission, comp_rouge, span_score, comp_rouge_ratio,
                    self.loss_thres)
            else:
                raise NotImplementedError

        else:
            compression_loss = 0
        # Decoding:
        if (self.dbg is True) or (self.training is False):

            if self.compression:
                if self.aggressive_compression > 0:
                    self.compression_dec.decode(
                        decoder_outputs_logit=sent_emission,
                        span_score=all_scores,
                        metadata=metadata,
                        span_meta=comp_meta,
                        span_seq_label=comp_seq_label,
                        span_rouge=comp_rouge,
                        compress_num=self.aggressive_compression)
                elif self.aggressive_compression < 0:
                    # for thres in self.compression_dec.keep_thres:
                    span_score = span_score.detach()
                    self.compression_dec.decode_inf_deletion(
                        sent_decoder_outputs_logit=sent_emission,
                        span_prob=span_score,
                        metadata=metadata,
                        span_meta=comp_meta,
                        span_rouge=comp_rouge,
                        keep_threshold=self.compression_dec.keep_thres)
                else:
                    raise NotImplementedError

        if self.compression:
            if random.random() < 0.002:
                print("Comp loss: {}".format(compression_loss))
            if self.comp_leadn > 0:
                return {"loss": compression_loss}
            else:
                # print("sent: {}\tcomp: {}".format(sent_loss, compression_loss))
                return {
                    "loss": sent_loss + self.alpha * compression_loss,
                    "sent_loss": sent_loss,
                    "compression_loss": compression_loss
                }
        else:
            return {"loss": sent_loss, "sent_loss": ori_loss}

    def get_metrics(self, reset: bool = False, note="") -> Dict[str, float]:
        # ROUGE
        _rouge_met_sent = self.rouge_metrics_sent.get_metric(reset=reset,
                                                             note=note)
        if self.compression:
            # _rouge_met_compression_ub = self.rouge_metrics_compression_upper_bound.get_metric(reset, note=note)
            if self.aggressive_compression < 0:
                dic = self.compression_dec.rouge_metrics_compression_dict
                new_dict = para_get_metric(dic, reset, note)
                return {**new_dict, **_rouge_met_sent}
            else:
                pass
                # _rouge_met_compression = self.rouge_metrics_compression.get_metric(reset=reset, note=note)
        else:
            return _rouge_met_sent

    def decode(self,
               output_dict: Dict[str, torch.Tensor],
               max_decoding_steps: int = 6,
               fix_edu_num: int = -1,
               min_step_decoding=2) -> Dict[str, torch.Tensor]:
        """

        :param output_dict: ["decoder_outputs_logit", "decoder_outputs_prob"
            "spans", "loss", "label",
            "metadata"["doc_list", "abs_list"] ]
        :param max_decoding_steps:
        :return:
        """
        assert output_dict["loss"] is not None
        meta = output_dict["metadata"]
        output_logit = output_dict[
            "decoder_outputs_logit"][:max_decoding_steps, :]
        output_logit = output_logit.cpu().numpy()

        output_prob = output_dict["decoder_outputs_prob"]

        span_info = output_dict["spans"]
        label = output_dict["label"]

        batch_size = len(meta)
        for idx, m in enumerate(meta):
            _label = label[idx]  # label: batch, step
            logit = output_logit[:, idx]
            prob = output_prob[:, idx, :]  # step, src_len
            sp = span_info[idx]
            name = m["name"]
            formal_doc = m['doc_list']
            formal_abs = m['abs_list']
            abs_s = convert_list_to_paragraph(formal_abs)

            _pred = []

            if fix_edu_num:
                prob = prob[:fix_edu_num, :]
                prob[:, 0] = -1000
                max_idx = torch.argmax(prob, dim=1)
                # predict exactly fix_edu_num of stuff
                # if 0 or unreachable, use prob
                logit = max_idx.cpu().numpy()
                for l in logit:
                    try:
                        start_idx = int(sp[l][0].item())
                        end_idx = int(sp[l][1].item())
                        words = formal_doc[start_idx:end_idx + 1]  # list
                        _pred.append(' '.join(words).replace('@@SS@@', ''))
                    except IndexError:
                        logging.error("----Out of range-----")
            else:
                # reach minimum requirement (2) of edu num and follow the prediction
                max_idx = torch.argmax(prob, dim=1)
                logit = max_idx.cpu().numpy()

                prob[:, 0] = -1000
                backup_max_idx = torch.argmax(prob, dim=1)
                backup_logit = backup_max_idx.cpu().numpy()

                for t, l in enumerate(logit):
                    if t < min_step_decoding and abs(l) < 0.01:
                        l = backup_logit[t]
                    elif abs(l) < 0.01:
                        break
                    try:
                        start_idx = int(sp[l][0].item())
                        end_idx = int(sp[l][1].item())
                        words = formal_doc[start_idx:end_idx + 1]  # list
                        _pred.append(' '.join(words).replace('@@SS@@', ''))
                    except IndexError:
                        logging.error("----Out of range-----")

            if random.random() < 0.1:
                log_predict_example(name=name,
                                    pred_label=logit,
                                    gold_label=_label,
                                    pred_abs=_pred,
                                    gold_abs=abs_s)
            self.rouge_metrics(pred=_pred, ref=[abs_s])
        return output_dict

    def refine_sent_selection(self, batchsz, comp_leadn, decoder_outputs_logit,
                              sent_label, decoder_outputs_prob, metadata):

        if comp_leadn > 0:
            part = metadata[0]['part']
            if part == 'cnn':
                comp_leadn -= 1

            lead3 = torch.ones_like(decoder_outputs_logit,
                                    dtype=torch.long,
                                    device=self.device) * -1
            assert decoder_outputs_logit.size()[1] == batchsz
            _t = decoder_outputs_logit.size()[0]

            for i in range(comp_leadn):
                if _t > i and comp_leadn >= i:
                    lead3[i, :] = i

            sent_emission = lead3
        else:
            rand_num = random.random()
            if self.training and (rand_num < 0.9):
                # use ground truth
                sent_emission = sent_label[:, 0, :]
                sent_emission = flip_first_two_dim(sent_emission.long())
            else:
                sent_decoded = self.sent_dec.decode(decoder_outputs_prob,
                                                    metadata, sent_label)
                sent_emission = sent_decoded
                # print(sent_emission.size()[0])
                # print(decoder_outputs_logit.size()[0])
                # assert sent_emission.size()[0] == decoder_outputs_logit.size()[0]
        return sent_emission
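
A hedged sketch of the use_elmo branch in the constructor above: the small-ELMo representation is concatenated with a word embedding (a random stand-in here) and fed to the bidirectional LSTM context layer wrapped as an AllenNLP Seq2SeqEncoder. The URLs are the ones given in the constructor; the dimensions and example sentence are illustrative.

import torch
from allennlp.modules.elmo import Elmo, batch_to_ids
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper

options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"

word_embedding_dim, hidden_dim = 200, 200
elmo = Elmo(options_file, weight_file, 1, dropout=0)
context_layer = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(word_embedding_dim + elmo.get_output_dim(), hidden_dim,
                  batch_first=True, bidirectional=True))

sentences = [["An", "example", "sentence", "."]]
elmo_rep = elmo(batch_to_ids(sentences))['elmo_representations'][0]  # [1, 4, 256]
word_emb = torch.randn(1, 4, word_embedding_dim)                     # stand-in for GloVe
mask = torch.ones(1, 4)
encoded = context_layer(torch.cat([word_emb, elmo_rep], dim=-1), mask)
print(encoded.shape)                                                 # [1, 4, 2 * hidden_dim]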
Example #19
class LexiconEncoder(nn.Module):
    """
    Each token p_i in the passage is represented as a 600-dimensional vector
    and each token q_i is represented as a 300-dimensional vector.
    """
    def create_embed(self, vocab_size, embed_dim, padding_idx=0):
        return nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)

    def create_word_embed(self, embedding=None, opt={}, prefix='wemb'):
        vocab_size = opt.get('vocab_size', 1)
        embed_dim = opt.get('{}_dim'.format(prefix), 300)
        self.embedding = self.create_embed(vocab_size, embed_dim)
        if embedding is not None:
            self.embedding.weight.data = embedding
            if opt['fix_embeddings'] or opt['tune_partial'] == 0:
                opt['fix_embeddings'] = True
                opt['tune_partial'] = 0
                for p in self.embedding.parameters():
                    p.requires_grad = False
            else:
                assert opt['tune_partial'] < embedding.size(0)
                fixed_embedding = embedding[opt['tune_partial']:]
                self.register_buffer('fixed_embedding', fixed_embedding)
                self.fixed_embedding = fixed_embedding
        return embed_dim

    def create_pos_embed(self, opt={}, prefix='pos'):
        vocab_size = opt.get('{}_vocab_size'.format(prefix), 56)
        embed_dim = opt.get('{}_dim'.format(prefix), 12)
        self.pos_embedding = self.create_embed(vocab_size, embed_dim)
        return embed_dim

    def create_ner_embed(self, opt={}, prefix='ner'):
        vocab_size = opt.get('{}_vocab_size'.format(prefix), 19)
        embed_dim = opt.get('{}_dim'.format(prefix), 8)
        self.ner_embedding = self.create_embed(vocab_size, embed_dim)
        return embed_dim

    def create_elmo_embed(self, opt={}, prefix='elmo'):
        # TODO
        options_file = os.path.join(opt['data_dir'],
                                    opt.get('{}_options_file'.format(prefix)))
        weights_file = os.path.join(opt['data_dir'],
                                    opt.get('{}_weights_file'.format(prefix)))
        self.elmo = Elmo(options_file, weights_file, 2, dropout=0)

        self.elmo_output_dim = self.elmo.get_output_dim()

        return self.elmo_output_dim

    def create_cove(self,
                    vocab_size,
                    embedding=None,
                    embed_dim=300,
                    padding_idx=0,
                    opt=None):
        self.ContextualEmbed = ContextualEmbed(os.path.join(
            opt['data_dir'], opt['covec_path']),
                                               opt['vocab_size'],
                                               embedding=embedding,
                                               padding_idx=padding_idx)
        return self.ContextualEmbed.output_size

    def create_prealign(self, x1_dim, x2_dim, opt={}, prefix='prealign'):
        self.prealign = AttentionWrapper(x1_dim, x2_dim, prefix, opt,
                                         self.dropout)

    def __init__(self,
                 opt,
                 pwnn_on=True,
                 embedding=None,
                 padding_idx=0,
                 dropout=None):
        super(LexiconEncoder, self).__init__()
        doc_input_size = 0
        que_input_size = 0
        self.dropout = DropoutWrapper(
            opt['dropout_p']) if dropout is None else dropout
        self.dropout_emb = DropoutWrapper(opt['dropout_emb'])
        # word embedding
        embedding_dim = self.create_word_embed(embedding, opt)
        self.embedding_dim = embedding_dim
        doc_input_size += embedding_dim
        que_input_size += embedding_dim

        # elmo
        elmo_size = self.create_elmo_embed(opt=opt) if opt['elmo_on'] else 0
        self.elmo_size = elmo_size

        # pre-trained contextual vector
        covec_size = self.create_cove(opt['vocab_size'], embedding,
                                      opt=opt) if opt['covec_on'] else 0
        self.covec_size = covec_size

        prealign_size = 0
        if opt['prealign_on'] and embedding_dim > 0:
            prealign_size = embedding_dim
            self.create_prealign(embedding_dim, embedding_dim, opt)
        self.prealign_size = prealign_size
        pos_size = self.create_pos_embed(opt) if opt['pos_on'] else 0
        ner_size = self.create_ner_embed(opt) if opt['ner_on'] else 0
        feat_size = opt['num_features'] if opt['feat_on'] else 0
        print(feat_size)
        doc_hidden_size = embedding_dim + covec_size + prealign_size + pos_size + ner_size + feat_size + elmo_size
        que_hidden_size = embedding_dim + covec_size + elmo_size
        if opt['prealign_bidi']:
            que_hidden_size += prealign_size
        self.pwnn_on = pwnn_on
        self.opt = opt
        if self.pwnn_on:
            # To map both passage and question lexical encodings into the same dimension.
            self.doc_pwnn = PositionwiseNN(doc_hidden_size,
                                           opt['pwnn_hidden_size'], dropout)
            if doc_hidden_size == que_hidden_size:
                self.que_pwnn = self.doc_pwnn
            else:
                self.que_pwnn = PositionwiseNN(que_hidden_size,
                                               opt['pwnn_hidden_size'],
                                               dropout)
            doc_input_size, que_input_size = opt['pwnn_hidden_size'], opt[
                'pwnn_hidden_size']
        self.doc_input_size = doc_input_size
        self.query_input_size = que_input_size

    def patch(self, v):
        # `async` became a reserved keyword in Python 3.7; use `non_blocking` instead.
        if self.opt['cuda']:
            v = Variable(v.cuda(non_blocking=True))
        else:
            v = Variable(v)
        return v

    def get_elmo_emb(self, char_ids):
        if self.opt['cuda']:
            char_ids = Variable(char_ids.cuda(non_blocking=True),
                                requires_grad=False)
        else:
            char_ids = Variable(char_ids, requires_grad=False)

        # Run ELMo once and return both output representations (low and high layer).
        sent_elmo = self.elmo(char_ids)
        return sent_elmo['elmo_representations']

    def forward(self, batch):
        """
        We obtain lexicon embedding by concatenating word embedding with POS, NER, 
        exact match, and pre-align (word embedding of the passage. Enhanced by questions.)
        """
        drnn_input_list = []
        qrnn_input_list = []
        emb = self.embedding if self.training else self.eval_embed
        doc_tok = self.patch(batch['doc_tok'])
        doc_mask = self.patch(batch['doc_mask'])
        query_tok = self.patch(batch['query_tok'])
        query_mask = self.patch(batch['query_mask'])

        doc_emb, query_emb = emb(doc_tok), emb(query_tok)
        # Dropout on embeddings
        if self.opt['dropout_emb'] > 0:
            doc_emb = self.dropout_emb(doc_emb)
            query_emb = self.dropout_emb(query_emb)
        drnn_input_list.append(doc_emb)
        qrnn_input_list.append(query_emb)

        doc_cove_low, doc_cove_high = None, None
        query_cove_low, query_cove_high = None, None
        if self.opt['covec_on']:
            doc_cove_low, doc_cove_high = self.ContextualEmbed(
                doc_tok, doc_mask)
            query_cove_low, query_cove_high = self.ContextualEmbed(
                query_tok, query_mask)
            doc_cove_low = self.dropout(doc_cove_low)
            doc_cove_high = self.dropout(doc_cove_high)
            query_cove_low = self.dropout(query_cove_low)
            query_cove_high = self.dropout(query_cove_high)
            drnn_input_list.append(doc_cove_low)
            qrnn_input_list.append(query_cove_low)

        # elmo
        doc_elmo_low, doc_elmo_high = None, None
        query_elmo_low, query_elmo_high = None, None
        if self.opt['elmo_on']:
            doc_elmo_low, doc_elmo_high = self.get_elmo_emb(
                batch['doc_char_ids'])
            query_elmo_low, query_elmo_high = self.get_elmo_emb(
                batch['query_char_ids'])

            doc_elmo_low = self.dropout(doc_elmo_low)
            doc_elmo_high = self.dropout(doc_elmo_high)
            query_elmo_low = self.dropout(query_elmo_low)
            query_elmo_high = self.dropout(query_elmo_high)
            drnn_input_list.append(doc_elmo_low)
            qrnn_input_list.append(query_elmo_low)

        if self.opt['prealign_on']:
            q2d_atten = self.prealign(doc_emb, query_emb, query_mask)
            d2q_atten = self.prealign(query_emb, doc_emb, doc_mask)
            drnn_input_list.append(q2d_atten)
            if self.opt['prealign_bidi']:
                qrnn_input_list.append(d2q_atten)

        if self.opt['pos_on']:
            doc_pos = self.patch(batch['doc_pos'])
            doc_pos_emb = self.pos_embedding(doc_pos)
            #doc_pos_emb = self.dropout(doc_pos_emb)
            drnn_input_list.append(doc_pos_emb)

        if self.opt['ner_on']:
            doc_ner = self.patch(batch['doc_ner'])
            doc_ner_emb = self.ner_embedding(doc_ner)
            #doc_ner_emb = self.dropout(doc_ner_emb)
            drnn_input_list.append(doc_ner_emb)

        if self.opt['feat_on']:
            doc_fea = self.patch(batch['doc_fea'])
            drnn_input_list.append(doc_fea)

        doc_input = torch.cat(drnn_input_list, 2)
        query_input = torch.cat(qrnn_input_list, 2)
        if self.pwnn_on:
            doc_input = self.doc_pwnn(doc_input)
            query_input = self.que_pwnn(query_input)
        doc_input = self.dropout(doc_input)
        query_input = self.dropout(query_input)
        return doc_input, query_input, doc_emb, query_emb, doc_cove_low, doc_cove_high, query_cove_low, query_cove_high, doc_elmo_low, doc_elmo_high, query_elmo_low, query_elmo_high, doc_mask, query_mask
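
The constructor above reads a fairly large set of `opt` keys. As a rough guide, here is a hypothetical minimal configuration inferred from the keys the class actually accesses; the values are illustrative, the CoVe/ELMo/pre-align inputs are switched off so no external files are needed, and DropoutWrapper / PositionwiseNN are assumed importable from the same codebase.

import torch

opt = {
    'vocab_size': 5000,
    'fix_embeddings': True, 'tune_partial': 0,
    'dropout_p': 0.3, 'dropout_emb': 0.3,
    'elmo_on': False, 'covec_on': False,
    'prealign_on': False, 'prealign_bidi': False,
    'pos_on': True, 'ner_on': True,
    'feat_on': True, 'num_features': 4,
    'pwnn_hidden_size': 256,
    'cuda': False,
}

embedding = torch.randn(opt['vocab_size'], 300)   # stand-in for pre-trained word vectors
encoder = LexiconEncoder(opt,
                         embedding=embedding,
                         dropout=DropoutWrapper(opt['dropout_p']))
print(encoder.doc_input_size, encoder.query_input_size)   # both equal pwnn_hidden_size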
print ("-------------- EMBEDDING LAYER ---------------")
if (use_ELMO):
    if (load_ELMO_experiments_flag):
        options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

        print ("Loading ELMO")
        text_field_embedder = Elmo(options_file, weight_file, 2, dropout=0)
        print ("ELMO weights loaded")
else:
    # Fall back to a plain, fixed (non-trainable) word embedding.
    # NOTE: depending on the AllenNLP version, Embedding also needs `num_embeddings`
    # (or a Vocabulary) to size the lookup table.
    text_field_embedder = Embedding(embedding_dim=100, trainable=False)

## Parameters needed for the next layer
embedder_out_dim = text_field_embedder.get_output_dim()

print ("Embedder output dimensions: ", embedder_out_dim)
## Propagate the batch through the embedder
embeddings_batch_question = text_field_embedder(character_ids_question)["elmo_representations"][1]
embeddings_batch_passage = text_field_embedder( character_ids_passage)["elmo_representations"][1]

#print (embeddings_batch_question)
print ("Question representations: ", embeddings_batch_question.shape)
print ("Passage representations: ", embeddings_batch_passage.shape)
print ("Batch size: ", embeddings_batch_question.shape[0])
print ("Maximum num words in question: ", embeddings_batch_question.shape[1])
print ("Word representation dimensionality: ", embeddings_batch_question.shape[2])

batch_size = embeddings_batch_question.size(0)
passage_length = embeddings_batch_passage.size(1)
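
For context, a self-contained sketch (assuming AllenNLP is installed; the URLs are the same public ELMo files used above) of how the character_ids_question / character_ids_passage tensors consumed by this snippet are typically produced with batch_to_ids:

from allennlp.modules.elmo import Elmo, batch_to_ids

options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

elmo = Elmo(options_file, weight_file, 2, dropout=0)

# Toy, already-tokenised questions; each inner list is one sentence.
questions = [["Who", "wrote", "Hamlet", "?"],
             ["Where", "is", "the", "Louvre", "located", "?"]]
character_ids_question = batch_to_ids(questions)        # (batch, max_words, 50)

elmo_out = elmo(character_ids_question)
representations = elmo_out["elmo_representations"]      # list of 2 tensors (num_output_representations=2)
mask = elmo_out["mask"]                                 # (batch, max_words)
print(representations[1].shape)                         # e.g. torch.Size([2, 6, 1024])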
Example #21
0
class HFet(nn.Module):
    def __init__(
        self,
        label_size,
        elmo_option,
        elmo_weight,
        elmo_dropout=.5,
        repr_dropout=.2,
        dist_dropout=.5,
        latent_size=0,
        svd=None,
    ):
        super(HFet, self).__init__()
        self.label_size = label_size
        self.elmo = Elmo(elmo_option, elmo_weight, 1, dropout=elmo_dropout)
        self.elmo_dim = self.elmo.get_output_dim()

        self.attn_dim = 1
        self.attn_inner_dim = self.elmo_dim
        # Mention attention
        self.men_attn_linear_m = nn.Linear(self.elmo_dim,
                                           self.attn_inner_dim,
                                           bias=False)
        self.men_attn_linear_o = nn.Linear(self.attn_inner_dim,
                                           self.attn_dim,
                                           bias=False)
        # Context attention
        self.ctx_attn_linear_c = nn.Linear(self.elmo_dim,
                                           self.attn_inner_dim,
                                           bias=False)
        self.ctx_attn_linear_m = nn.Linear(self.elmo_dim,
                                           self.attn_inner_dim,
                                           bias=False)
        self.ctx_attn_linear_d = nn.Linear(1, self.attn_inner_dim, bias=False)
        self.ctx_attn_linear_o = nn.Linear(self.attn_inner_dim,
                                           self.attn_dim,
                                           bias=False)
        # Output linear layers
        self.repr_dropout = nn.Dropout(p=repr_dropout)
        self.output_linear = nn.Linear(self.elmo_dim * 2,
                                       label_size,
                                       bias=False)

        # SVD initialization for the latent-to-label projection
        svd_mat = None
        if svd:
            svd_mat = self.load_svd(svd)
            self.latent_size = svd_mat.size(1)
        elif latent_size == 0:
            self.latent_size = int(math.sqrt(label_size))
        else:
            self.latent_size = latent_size
        self.latent_to_label = nn.Linear(self.latent_size,
                                         label_size,
                                         bias=False)
        if svd_mat is not None:
            # Initialize the projection from the SVD factors and keep it frozen
            # (the original created the parameter and then disabled gradients).
            self.latent_to_label.weight = nn.Parameter(svd_mat,
                                                       requires_grad=False)
        self.latent_scalar = nn.Parameter(torch.FloatTensor([.1]))
        self.feat_to_latent = nn.Linear(self.elmo_dim * 2,
                                        self.latent_size,
                                        bias=False)
        # Loss function
        self.criterion = nn.MultiLabelSoftMarginLoss()
        self.mse = nn.MSELoss()
        # Relative position (distance)
        self.dist_dropout = nn.Dropout(p=dist_dropout)

    def load_svd(self, path):
        print('Loading SVD matrices')
        u_file = path + '-Ut'
        s_file = path + '-S'
        with open(s_file, 'r', encoding='utf-8') as r:
            s_num = int(r.readline().rstrip())
            mat_s = [[0] * s_num for _ in range(s_num)]
            for i in range(s_num):
                mat_s[i][i] = float(r.readline().rstrip())
        mat_s = torch.FloatTensor(mat_s)

        with open(u_file, 'r', encoding='utf-8') as r:
            mat_u = []
            r.readline()
            for line in r:
                mat_u.append([float(i) for i in line.rstrip().split()])
        mat_u = torch.FloatTensor(mat_u).transpose(0, 1)
        return torch.matmul(mat_u, mat_s)  #.transpose(0, 1)

    def forward_nn(self, inputs, men_mask, ctx_mask, dist, gathers):
        # Elmo contextualized embeddings
        elmo_outputs = self.elmo(inputs)['elmo_representations'][0]
        _, seq_len, feat_dim = elmo_outputs.size()
        gathers = gathers.unsqueeze(-1).unsqueeze(-1).expand(
            -1, seq_len, feat_dim)
        elmo_outputs = torch.gather(elmo_outputs, 0, gathers)

        men_attn = self.men_attn_linear_m(elmo_outputs).tanh()
        men_attn = self.men_attn_linear_o(men_attn)
        men_attn = men_attn + (1.0 - men_mask.unsqueeze(-1)) * -10000.0
        men_attn = men_attn.softmax(1)
        men_repr = (elmo_outputs * men_attn).sum(1)

        dist = self.dist_dropout(dist)
        ctx_attn = (self.ctx_attn_linear_c(elmo_outputs) +
                    self.ctx_attn_linear_m(men_repr.unsqueeze(1)) +
                    self.ctx_attn_linear_d(dist.unsqueeze(2))).tanh()
        ctx_attn = self.ctx_attn_linear_o(ctx_attn)

        ctx_attn = ctx_attn + (1.0 - ctx_mask.unsqueeze(-1)) * -10000.0
        ctx_attn = ctx_attn.softmax(1)
        ctx_repr = (elmo_outputs * ctx_attn).sum(1)

        # Classification
        final_repr = torch.cat([men_repr, ctx_repr], dim=1)
        final_repr = self.repr_dropout(final_repr)
        outputs = self.output_linear(final_repr)

        outputs_latent = None
        latent_label = self.feat_to_latent(final_repr)  #.tanh()
        outputs_latent = self.latent_to_label(latent_label)
        outputs = outputs + self.latent_scalar * outputs_latent

        return outputs, outputs_latent

    def forward(self,
                inputs,
                labels,
                men_mask,
                ctx_mask,
                dist,
                gathers,
                inst_weights=None):
        outputs, outputs_latent = self.forward_nn(inputs, men_mask, ctx_mask,
                                                  dist, gathers)
        loss = self.criterion(outputs, labels)
        return loss

    def _prediction(self, outputs, predict_top=True):
        _, highest = outputs.max(dim=1)
        highest = highest.int().tolist()
        preds = (outputs.sigmoid() > .5).int()
        if predict_top:
            for i, h in enumerate(highest):
                preds[i][h] = 1
        return preds

    def predict(self,
                inputs,
                men_mask,
                ctx_mask,
                dist,
                gathers,
                predict_top=True):
        self.eval()
        outputs, _ = self.forward_nn(inputs, men_mask, ctx_mask, dist, gathers)
        predictions = self._prediction(outputs, predict_top=predict_top)
        self.train()
        return predictions
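
A hypothetical smoke test for HFet with toy tensors and the public ELMo files; the mask/dist/gathers shapes follow what forward_nn() above expects, but the concrete mentions, labels, and values are illustrative assumptions.

import torch
from allennlp.modules.elmo import batch_to_ids

options_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

model = HFet(label_size=10, elmo_option=options_file, elmo_weight=weight_file)

sentences = [["Barack", "Obama", "visited", "Paris", "yesterday", "."],
             ["The", "Louvre", "is", "a", "museum", "."]]
inputs = batch_to_ids(sentences)                 # (2, 6, 50) character ids
seq_len = inputs.size(1)

# Three mentions; gathers[i] is the index of the sentence mention i belongs to.
gathers = torch.tensor([0, 0, 1])
men_mask = torch.zeros(3, seq_len)               # 1.0 on the mention tokens
men_mask[0, 0:2] = 1.0                           # "Barack Obama"
men_mask[1, 3] = 1.0                             # "Paris"
men_mask[2, 1] = 1.0                             # "Louvre"
ctx_mask = torch.ones(3, seq_len)                # 1.0 on valid context tokens
dist = torch.rand(3, seq_len)                    # relative-position feature per token
labels = torch.zeros(3, 10)
labels[0, 1] = labels[1, 3] = labels[2, 5] = 1.0 # gold fine-grained types (multi-label)

loss = model(inputs, labels, men_mask, ctx_mask, dist, gathers)
preds = model.predict(inputs, men_mask, ctx_mask, dist, gathers)   # (3, 10) 0/1 tensor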
class ElmoTokenEmbedder(TokenEmbedder):
    """
    Compute a single layer of ELMo representations.

    This class serves as a convenience when you only want to use one layer of
    ELMo representations at the input of your network.  It's essentially a wrapper
    around Elmo(num_output_representations=1, ...)

    Parameters
    ----------
    options_file : ``str``, required.
        An ELMo JSON options file.
    weight_file : ``str``, required.
        An ELMo hdf5 weight file.
    do_layer_norm : ``bool``, optional.
        Should we apply layer normalization (passed to ``ScalarMix``)?
    dropout : ``float``, optional.
        The dropout value to be applied to the ELMo representations.
    requires_grad : ``bool``, optional
        If True, compute gradient of ELMo parameters for fine tuning.
    projection_dim : ``int``, optional
        If given, we will project the ELMo embedding down to this dimension.  We recommend that you
        try using ELMo with a lot of dropout and no projection first, but we have found a few cases
        where projection helps (particularly where there is very limited training data).
    """
    def __init__(self,
                 options_file: str,
                 weight_file: str,
                 do_layer_norm: bool = False,
                 dropout: float = 0.5,
                 requires_grad: bool = False,
                 projection_dim: int = None) -> None:
        super(ElmoTokenEmbedder, self).__init__()

        self._elmo = Elmo(options_file,
                          weight_file,
                          1,
                          do_layer_norm=do_layer_norm,
                          dropout=dropout,
                          requires_grad=requires_grad)
        if projection_dim:
            self._projection = torch.nn.Linear(self._elmo.get_output_dim(), projection_dim)
        else:
            self._projection = None

    def get_output_dim(self):
        return self._elmo.get_output_dim()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.

        Returns
        -------
        The ELMo representations for the input sequence, shape
        ``(batch_size, timesteps, embedding_dim)``
        """
        elmo_output = self._elmo(inputs)
        elmo_representations = elmo_output['elmo_representations'][0]
        if self._projection:
            projection = self._projection
            for _ in range(elmo_representations.dim() - 2):
                projection = TimeDistributed(projection)
            elmo_representations = projection(elmo_representations)
        return elmo_representations

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ElmoTokenEmbedder':
        params.add_file_to_archive('options_file')
        params.add_file_to_archive('weight_file')
        options_file = params.pop('options_file')
        weight_file = params.pop('weight_file')
        requires_grad = params.pop('requires_grad', False)
        do_layer_norm = params.pop_bool('do_layer_norm', False)
        dropout = params.pop_float("dropout", 0.5)
        projection_dim = params.pop_int("projection_dim", None)
        params.assert_empty(cls.__name__)
        return cls(options_file=options_file,
                   weight_file=weight_file,
                   do_layer_norm=do_layer_norm,
                   dropout=dropout,
                   requires_grad=requires_grad,
                   projection_dim=projection_dim)
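
A brief usage sketch of the wrapper above, assuming AllenNLP and the public ELMo files; note that this variant's get_output_dim() still reports the un-projected ELMo size even when projection_dim is set, which the later version in Example #27 corrects by tracking self.output_dim.

from allennlp.modules.elmo import batch_to_ids

options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

embedder = ElmoTokenEmbedder(options_file, weight_file,
                             dropout=0.5, projection_dim=256)

char_ids = batch_to_ids([["A", "small", "projected", "example", "."]])  # (1, 5, 50)
vectors = embedder(char_ids)         # (1, 5, 256) after the linear projection
print(vectors.shape)
print(embedder.get_output_dim())     # 1024 here: the un-projected ELMo dimension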
print("-------------- EMBEDDING LAYER ---------------")
if (use_ELMO):
    if (load_ELMO_experiments_flag):
        options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
        weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

        print("Loading ELMO")
        text_field_embedder = Elmo(options_file, weight_file, 2, dropout=0)
        print("ELMO weights loaded")
else:
    text_field_embedder = TextFieldEmbedder()
    token_embedders = dict()
    text_field_embedder = Embedding(embedding_dim=100, trainable=False)

## Parameters needed for the next layer
embedder_out_dim = text_field_embedder.get_output_dim()

print("Embedder output dimensions: ", embedder_out_dim)
## Propagate the Batch though the Embedder
embeddings_batch_question = text_field_embedder(
    character_ids_question)["elmo_representations"][1]
embeddings_batch_passage = text_field_embedder(
    character_ids_passage)["elmo_representations"][1]

#print (embeddings_batch_question)
print("Question representations: ", embeddings_batch_question.shape)
print("Passage representations: ", embeddings_batch_passage.shape)
print("Batch size: ", embeddings_batch_question.shape[0])
print("Maximum num words in question: ", embeddings_batch_question.shape[1])
print("Word representation dimensionality: ",
      embeddings_batch_question.shape[2])
Example #24
0
class ElmoTokenEmbedder(TokenEmbedder):
    """
    Compute a single layer of ELMo representations.

    This class serves as a convenience when you only want to use one layer of
    ELMo representations at the input of your network.  It's essentially a wrapper
    around Elmo(num_output_representations=1, ...)

    Parameters
    ----------
    options_file : ``str``, required.
        An ELMo JSON options file.
    weight_file : ``str``, required.
        An ELMo hdf5 weight file.
    do_layer_norm : ``bool``, optional.
        Should we apply layer normalization (passed to ``ScalarMix``)?
    dropout : ``float``, optional.
        The dropout value to be applied to the ELMo representations.
    requires_grad : ``bool``, optional
        If True, compute gradient of ELMo parameters for fine tuning.
    projection_dim : ``int``, optional
        If given, we will project the ELMo embedding down to this dimension.  We recommend that you
        try using ELMo with a lot of dropout and no projection first, but we have found a few cases
        where projection helps (particularly where there is very limited training data).
    vocab_to_cache : ``List[str]``, optional, (default = None).
        A list of words to pre-compute and cache character convolutions
        for. If you use this option, the ElmoTokenEmbedder expects that you pass word
        indices of shape (batch_size, timesteps) to forward, instead
        of character indices. If you use this option and pass a word which
        wasn't pre-cached, this will break.
    """
    def __init__(self,
                 options_file: str,
                 weight_file: str,
                 do_layer_norm: bool = False,
                 dropout: float = 0.5,
                 requires_grad: bool = False,
                 projection_dim: int = None,
                 vocab_to_cache: List[str] = None) -> None:
        super(ElmoTokenEmbedder, self).__init__()

        self._elmo = Elmo(options_file,
                          weight_file,
                          1,
                          do_layer_norm=do_layer_norm,
                          dropout=dropout,
                          requires_grad=requires_grad,
                          vocab_to_cache=vocab_to_cache)
        if projection_dim:
            self._projection = torch.nn.Linear(self._elmo.get_output_dim(), projection_dim)
        else:
            self._projection = None

    def get_output_dim(self):
        return self._elmo.get_output_dim()

    def forward(self, # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                word_inputs: torch.Tensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.
        word_inputs : ``torch.Tensor``, optional.
            If you passed a cached vocab, you can in addition pass a tensor of shape
            ``(batch_size, timesteps)``, which represent word ids which have been pre-cached.

        Returns
        -------
        The ELMo representations for the input sequence, shape
        ``(batch_size, timesteps, embedding_dim)``
        """
        elmo_output = self._elmo(inputs, word_inputs)
        elmo_representations = elmo_output['elmo_representations'][0]
        if self._projection:
            projection = self._projection
            for _ in range(elmo_representations.dim() - 2):
                projection = TimeDistributed(projection)
            elmo_representations = projection(elmo_representations)
        return elmo_representations

    # Custom vocab_to_cache logic requires a from_params implementation.
    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ElmoTokenEmbedder':  # type: ignore
        # pylint: disable=arguments-differ
        params.add_file_to_archive('options_file')
        params.add_file_to_archive('weight_file')
        options_file = params.pop('options_file')
        weight_file = params.pop('weight_file')
        requires_grad = params.pop('requires_grad', False)
        do_layer_norm = params.pop_bool('do_layer_norm', False)
        dropout = params.pop_float("dropout", 0.5)
        namespace_to_cache = params.pop("namespace_to_cache", None)
        if namespace_to_cache is not None:
            vocab_to_cache = list(vocab.get_token_to_index_vocabulary(namespace_to_cache).keys())
        else:
            vocab_to_cache = None
        projection_dim = params.pop_int("projection_dim", None)
        params.assert_empty(cls.__name__)
        return cls(options_file=options_file,
                   weight_file=weight_file,
                   do_layer_norm=do_layer_norm,
                   dropout=dropout,
                   requires_grad=requires_grad,
                   projection_dim=projection_dim,
                   vocab_to_cache=vocab_to_cache)
class ElmoTokenEmbedder(TokenEmbedder):
    u"""
    Compute a single layer of ELMo representations.

    This class serves as a convenience when you only want to use one layer of
    ELMo representations at the input of your network.  It's essentially a wrapper
    around Elmo(num_output_representations=1, ...)

    Parameters
    ----------
    options_file : ``str``, required.
        An ELMo JSON options file.
    weight_file : ``str``, required.
        An ELMo hdf5 weight file.
    do_layer_norm : ``bool``, optional.
        Should we apply layer normalization (passed to ``ScalarMix``)?
    dropout : ``float``, optional.
        The dropout value to be applied to the ELMo representations.
    requires_grad : ``bool``, optional
        If True, compute gradient of ELMo parameters for fine tuning.
    projection_dim : ``int``, optional
        If given, we will project the ELMo embedding down to this dimension.  We recommend that you
        try using ELMo with a lot of dropout and no projection first, but we have found a few cases
        where projection helps (particularly where there is very limited training data).
    vocab_to_cache : ``List[str]``, optional, (default = None).
        A list of words to pre-compute and cache character convolutions
        for. If you use this option, the ElmoTokenEmbedder expects that you pass word
        indices of shape (batch_size, timesteps) to forward, instead
        of character indices. If you use this option and pass a word which
        wasn't pre-cached, this will break.
    """
    def __init__(self,
                 options_file,
                 weight_file,
                 do_layer_norm=False,
                 dropout=0.5,
                 requires_grad=False,
                 projection_dim=None,
                 vocab_to_cache=None):
        super(ElmoTokenEmbedder, self).__init__()

        self._elmo = Elmo(options_file,
                          weight_file,
                          1,
                          do_layer_norm=do_layer_norm,
                          dropout=dropout,
                          requires_grad=requires_grad,
                          vocab_to_cache=vocab_to_cache)
        if projection_dim:
            self._projection = torch.nn.Linear(self._elmo.get_output_dim(),
                                               projection_dim)
        else:
            self._projection = None

    def get_output_dim(self):
        return self._elmo.get_output_dim()

    def forward(
            self,  # pylint: disable=arguments-differ
            inputs,
            word_inputs=None):
        u"""
        Parameters
        ----------
        inputs: ``torch.Tensor``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.
        word_inputs : ``torch.Tensor``, optional.
            If you passed a cached vocab, you can in addition pass a tensor of shape
            ``(batch_size, timesteps)``, which represent word ids which have been pre-cached.

        Returns
        -------
        The ELMo representations for the input sequence, shape
        ``(batch_size, timesteps, embedding_dim)``
        """
        elmo_output = self._elmo(inputs, word_inputs)
        elmo_representations = elmo_output[u'elmo_representations'][0]
        if self._projection:
            projection = self._projection
            for _ in range(elmo_representations.dim() - 2):
                projection = TimeDistributed(projection)
            elmo_representations = projection(elmo_representations)
        return elmo_representations

    # Custom vocab_to_cache logic requires a from_params implementation.
    @classmethod
    def from_params(cls, vocab, params):  # type: ignore
        # pylint: disable=arguments-differ
        params.add_file_to_archive(u'options_file')
        params.add_file_to_archive(u'weight_file')
        options_file = params.pop(u'options_file')
        weight_file = params.pop(u'weight_file')
        requires_grad = params.pop(u'requires_grad', False)
        do_layer_norm = params.pop_bool(u'do_layer_norm', False)
        dropout = params.pop_float(u"dropout", 0.5)
        namespace_to_cache = params.pop(u"namespace_to_cache", None)
        if namespace_to_cache is not None:
            vocab_to_cache = list(
                vocab.get_token_to_index_vocabulary(namespace_to_cache).keys())
        else:
            vocab_to_cache = None
        projection_dim = params.pop_int(u"projection_dim", None)
        params.assert_empty(cls.__name__)
        return cls(options_file=options_file,
                   weight_file=weight_file,
                   do_layer_norm=do_layer_norm,
                   dropout=dropout,
                   requires_grad=requires_grad,
                   projection_dim=projection_dim,
                   vocab_to_cache=vocab_to_cache)
Example #26
0
class ElmoTokenEmbedder(TokenEmbedder):
    """
    Compute a single layer of ELMo representations.

    This class serves as a convenience when you only want to use one layer of
    ELMo representations at the input of your network.  It's essentially a wrapper
    around Elmo(num_output_representations=1, ...)

    Parameters
    ----------
    options_file : ``str``, required.
        An ELMo JSON options file.
    weight_file : ``str``, required.
        An ELMo hdf5 weight file.
    do_layer_norm : ``bool``, optional.
        Should we apply layer normalization (passed to ``ScalarMix``)?
    dropout : ``float``, optional.
        The dropout value to be applied to the ELMo representations.
    requires_grad : ``bool``, optional
        If True, compute gradient of ELMo parameters for fine tuning.
    projection_dim : ``int``, optional
        If given, we will project the ELMo embedding down to this dimension.  We recommend that you
        try using ELMo with a lot of dropout and no projection first, but we have found a few cases
        where projection helps (particularly where there is very limited training data).
    vocab_to_cache : ``List[str]``, optional, (default = None).
        A list of words to pre-compute and cache character convolutions
        for. If you use this option, the ElmoTokenEmbedder expects that you pass word
        indices of shape (batch_size, timesteps) to forward, instead
        of character indices. If you use this option and pass a word which
        wasn't pre-cached, this will break.
    """
    def __init__(self,
                 options_file: str,
                 weight_file: str,
                 do_layer_norm: bool = False,
                 dropout: float = 0.5,
                 requires_grad: bool = False,
                 projection_dim: int = None,
                 vocab_to_cache: List[str] = None) -> None:
        super(ElmoTokenEmbedder, self).__init__()

        self._elmo = Elmo(options_file,
                          weight_file,
                          1,
                          do_layer_norm=do_layer_norm,
                          dropout=dropout,
                          requires_grad=requires_grad,
                          vocab_to_cache=vocab_to_cache)
        if projection_dim:
            self._projection = torch.nn.Linear(self._elmo.get_output_dim(),
                                               projection_dim)
        else:
            self._projection = None

    def get_output_dim(self):
        return self._elmo.get_output_dim()

    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            word_inputs: torch.Tensor = None) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``
            Shape ``(batch_size, timesteps, 50)`` of character ids representing the current batch.
        word_inputs : ``torch.Tensor``, optional.
            If you passed a cached vocab, you can in addition pass a tensor of shape
            ``(batch_size, timesteps)``, which represent word ids which have been pre-cached.

        Returns
        -------
        The ELMo representations for the input sequence, shape
        ``(batch_size, timesteps, embedding_dim)``
        """
        elmo_output = self._elmo(inputs, word_inputs)
        elmo_representations = elmo_output['elmo_representations'][0]
        if self._projection:
            projection = self._projection
            for _ in range(elmo_representations.dim() - 2):
                projection = TimeDistributed(projection)
            elmo_representations = projection(elmo_representations)
        return elmo_representations
Example #27
0
class ElmoTokenEmbedder(TokenEmbedder):
    """
    Compute a single layer of ELMo representations.

    This class serves as a convenience when you only want to use one layer of
    ELMo representations at the input of your network.  It's essentially a wrapper
    around Elmo(num_output_representations=1, ...)

    Registered as a `TokenEmbedder` with name "elmo_token_embedder".

    # Parameters

    options_file : `str`, required.
        An ELMo JSON options file.
    weight_file : `str`, required.
        An ELMo hdf5 weight file.
    do_layer_norm : `bool`, optional.
        Should we apply layer normalization (passed to `ScalarMix`)?
    dropout : `float`, optional, (default = `0.5`).
        The dropout value to be applied to the ELMo representations.
    requires_grad : `bool`, optional
        If True, compute gradient of ELMo parameters for fine tuning.
    projection_dim : `int`, optional
        If given, we will project the ELMo embedding down to this dimension.  We recommend that you
        try using ELMo with a lot of dropout and no projection first, but we have found a few cases
        where projection helps (particularly where there is very limited training data).
    vocab_to_cache : `List[str]`, optional.
        A list of words to pre-compute and cache character convolutions
        for. If you use this option, the ElmoTokenEmbedder expects that you pass word
        indices of shape (batch_size, timesteps) to forward, instead
        of character indices. If you use this option and pass a word which
        wasn't pre-cached, this will break.
    scalar_mix_parameters : `List[float]`, optional, (default=`None`)
        If not `None`, use these scalar mix parameters to weight the representations
        produced by different layers. These mixing weights are not updated during
        training. The mixing weights here should be the unnormalized (i.e., pre-softmax)
        weights. So, if you wanted to use only the 1st layer of a 2-layer ELMo,
        you can set this to [-9e10, 1, -9e10].
    """
    def __init__(
        self,
        options_file: str = (
            "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/"
            "elmo_2x4096_512_2048cnn_2xhighway_options.json"
        ),
        weight_file: str = (
            "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/"
            "elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
        ),
        do_layer_norm: bool = False,
        dropout: float = 0.5,
        requires_grad: bool = False,
        projection_dim: int = None,
        vocab_to_cache: List[str] = None,
        scalar_mix_parameters: List[float] = None,
    ) -> None:
        super().__init__()

        self._elmo = Elmo(
            options_file,
            weight_file,
            1,
            do_layer_norm=do_layer_norm,
            dropout=dropout,
            requires_grad=requires_grad,
            vocab_to_cache=vocab_to_cache,
            scalar_mix_parameters=scalar_mix_parameters,
        )
        if projection_dim:
            self._projection = torch.nn.Linear(self._elmo.get_output_dim(),
                                               projection_dim)
            self.output_dim = projection_dim
        else:
            self._projection = None
            self.output_dim = self._elmo.get_output_dim()

    def get_output_dim(self) -> int:
        return self.output_dim

    def forward(self,
                elmo_tokens: torch.Tensor,
                word_inputs: torch.Tensor = None) -> torch.Tensor:
        """
        # Parameters

        elmo_tokens : `torch.Tensor`
            Shape `(batch_size, timesteps, 50)` of character ids representing the current batch.
        word_inputs : `torch.Tensor`, optional.
            If you passed a cached vocab, you can in addition pass a tensor of shape
            `(batch_size, timesteps)`, which represent word ids which have been pre-cached.

        # Returns

        `torch.Tensor`
            The ELMo representations for the input sequence, shape
            `(batch_size, timesteps, embedding_dim)`
        """
        elmo_output = self._elmo(elmo_tokens, word_inputs)
        elmo_representations = elmo_output["elmo_representations"][0]
        if self._projection:
            projection = self._projection
            for _ in range(elmo_representations.dim() - 2):
                projection = TimeDistributed(projection)
            elmo_representations = projection(elmo_representations)
        return elmo_representations
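
As a concrete illustration of the scalar_mix_parameters option documented above: the pre-softmax weights below pin essentially all of the mixture weight on a single ELMo layer, exactly as in the docstring's example, and stay fixed during training. The class defaults are used for the options/weight files; the sentence is an illustrative assumption.

from allennlp.modules.elmo import batch_to_ids

embedder = ElmoTokenEmbedder(scalar_mix_parameters=[-9e10, 1, -9e10],
                             dropout=0.0,
                             requires_grad=False)

char_ids = batch_to_ids([["Scalar", "mix", "demo", "."]])   # (1, 4, 50)
vectors = embedder(char_ids)                                # (1, 4, 1024)
print(embedder.get_output_dim())                            # 1024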
class PretrainedInputEmbeddings(nn.Module):
    """Construct the embeddings from pre-trained ELMo\BERT\XLNet embeddings
    1) pretrained_model_type == 'tf', pretrained_model_info = {'type': 'bert', 'name': 'bert-base-uncased'}
        1.1): bert, bert-base-uncased, bert-base-cased, bert-base-chinese
        1.2): xlnet, xlnet-base-cased
    2) pretrained_model_type == 'elmo', pretrained_model_info = {'elmo_json': '', 'elmo_weight': ''}
    """
    def __init__(self,
                 pretrained_model_type='tf',
                 pretrained_model_info={},
                 dropout=0.0,
                 device=None):
        super(PretrainedInputEmbeddings, self).__init__()

        self.pretrained_model_type = pretrained_model_type.lower()
        self.pretrained_model_info = pretrained_model_info
        self.device = device

        assert self.pretrained_model_type in {'tf', 'elmo'}
        if self.pretrained_model_type == 'tf':
            if 'uncased' in pretrained_model_info['name']:
                input_word_lowercase = True
            else:
                input_word_lowercase = False
            self.tf_tokenizer, self.tf_model = load_pretrained_transformer(
                self.pretrained_model_info['type'],
                self.pretrained_model_info['name'],
                lowercase=input_word_lowercase)
            #self.tf_model.embeddings.word_embeddings = nn.Embedding(6500, 768, padding_idx=0)
            #self.tf_model.encoder.layer = self.tf_model.encoder.layer[:2]
            self.tf_input_args = {
                'cls_token_at_end':
                bool(self.pretrained_model_info['type'] in
                     ['xlnet']),  # xlnet has a cls token at the end
                'cls_token_segment_id':
                2 if self.pretrained_model_info['type'] in ['xlnet'] else 0,
                'pad_on_left':
                bool(self.pretrained_model_info['type'] in
                     ['xlnet']),  # pad on the left for xlnet
                'pad_token_segment_id':
                4 if self.pretrained_model_info['type'] in ['xlnet'] else 0,
            }
            self.embedding_dim = self.tf_model.config.hidden_size
        else:
            self.elmo_model = Elmo(pretrained_model_info['elmo_json'],
                                   pretrained_model_info['elmo_weight'],
                                   1,
                                   dropout=0)
            self.embedding_dim = self.elmo_model.get_output_dim()

        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, words, no_dropout=False):
        """
        words: a list of word lists (one token list per example)
        """
        if self.pretrained_model_type == 'tf':
            lengths = [len(ws) for ws in words]
            input_tf_ids, tf_segment_ids, tf_attention_mask, tf_output_selects, tf_output_copies = prepare_inputs_for_bert_xlnet(
                words,
                self.tf_tokenizer,
                device=self.device,
                **self.tf_input_args)
            input_tf = {
                "input_ids": input_tf_ids,
                "segment_ids": tf_segment_ids,
                "attention_mask": tf_attention_mask,
                "selects": tf_output_selects,
                "copies": tf_output_copies,
                "batch_size": len(lengths),
                "max_word_length": max(lengths)
            }
            embeds = transformer_forward_by_ignoring_suffix(self.tf_model,
                                                            **input_tf,
                                                            device=self.device)
        else:
            tokens = batch_to_ids(words).to(self.device)
            elmo_embeds = self.elmo_model(tokens)
            embeds = elmo_embeds['elmo_representations'][0]

        if not no_dropout:
            embeds = self.dropout_layer(embeds)

        return embeds
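
A hedged usage sketch for the 'elmo' branch of the class above; the option/weight entries point at the public original ELMo files, and the inputs are toy sentences.

import torch

pretrained_model_info = {
    'elmo_json': "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json",
    'elmo_weight': "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5",
}
embedder = PretrainedInputEmbeddings(pretrained_model_type='elmo',
                                     pretrained_model_info=pretrained_model_info,
                                     dropout=0.1,
                                     device=torch.device('cpu'))

words = [["the", "cat", "sat", "on", "the", "mat"],
         ["hello", "world"]]
embeds = embedder(words)             # (2, 6, 1024), padded to the longest sentence
print(embeds.shape, embedder.embedding_dim)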
Example #29
0
class PretrainedInputEmbeddings(nn.Module):
    """Construct the embeddings from pre-trained ELMo\BERT\XLNet embeddings
    1) pretrained_model_type == 'tf', pretrained_model_info = {'type': 'bert', 'name': 'bert-base-uncased', 'alignment': 'first'}
        1.1): bert, bert-base-uncased, bert-base-cased, bert-base-chinese
        1.2): xlnet, xlnet-base-cased
        1.3): alignemnt can be None, 'first' and 'avg'
    2) pretrained_model_type == 'elmo', pretrained_model_info = {'elmo_json': '', 'elmo_weight': ''}
    """
    def __init__(self,
                 pretrained_model_type='tf',
                 pretrained_model_info={},
                 dropout=0.0,
                 device=None):
        super(PretrainedInputEmbeddings, self).__init__()

        self.pretrained_model_type = pretrained_model_type.lower()
        self.pretrained_model_info = pretrained_model_info
        self.device = device

        assert self.pretrained_model_type in {'tf', 'elmo'}
        if self.pretrained_model_type == 'tf':
            if 'uncased' in pretrained_model_info['name']:
                input_word_lowercase = True
            else:
                input_word_lowercase = False
            self.tf_tokenizer, self.tf_model = load_pretrained_transformer(
                self.pretrained_model_info['type'],
                self.pretrained_model_info['name'],
                lowercase=input_word_lowercase)
            #self.tf_model.embeddings.word_embeddings = nn.Embedding(6500, 768, padding_idx=0)
            #self.tf_model.encoder.layer = self.tf_model.encoder.layer[:2]
            self.tf_input_args = {
                'cls_token_at_end':
                bool(self.pretrained_model_info['type'] in
                     ['xlnet']),  # xlnet has a cls token at the end
                'cls_token_segment_id':
                2 if self.pretrained_model_info['type'] in ['xlnet'] else 0,
                'pad_on_left':
                bool(self.pretrained_model_info['type'] in
                     ['xlnet']),  # pad on the left for xlnet
                'pad_token_segment_id':
                4 if self.pretrained_model_info['type'] in ['xlnet'] else 0,
            }
            if 'alignment' not in self.pretrained_model_info or self.pretrained_model_info[
                    'alignment'] not in {'first', 'avg', 'ori'}:
                self.alignment = None
            else:
                self.alignment = self.pretrained_model_info['alignment']
            self.embedding_dim = self.tf_model.config.hidden_size
        else:
            self.elmo_model = Elmo(pretrained_model_info['elmo_json'],
                                   pretrained_model_info['elmo_weight'],
                                   1,
                                   dropout=0)
            self.embedding_dim = self.elmo_model.get_output_dim()

        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, words, no_dropout=False):
        """
        words: a list of word list
        """
        """
        [NOTE]: If you want to feed output word embeddings into RNN/GRU/LSTM by using pack_padded_sequence, you'd better sort 'words' by length in advance.
        """
        if self.pretrained_model_type == 'tf':
            input_tf, tf_tokens, output_tokens, output_token_lengths = prepare_inputs_for_bert_xlnet(
                words,
                self.tf_tokenizer,
                device=self.device,
                **self.tf_input_args,
                alignment=self.alignment)
            embeds = transformer_forward_by_ignoring_suffix(
                self.tf_model,
                **input_tf,
                device=self.device,
                alignment=self.alignment)
        else:
            tokens = batch_to_ids(words).to(self.device)
            elmo_embeds = self.elmo_model(tokens)
            embeds = elmo_embeds['elmo_representations'][0]
            output_tokens = words
            output_token_lengths = [len(ws) for ws in words]

        if not no_dropout:
            embeds = self.dropout_layer(embeds)

        return embeds, output_tokens, output_token_lengths
Example #30
0
class ContextualControllerELMo(ControllerBase):
    def __init__(
            self,
            hidden_size,
            dropout,
            pretrained_embeddings_dir,
            dataset_name,
            fc_hidden_size=150,
            freeze_pretrained=True,
            learning_rate=0.001,
            layer_learning_rate: Optional[Dict[str, float]] = None,
            max_segment_size=None,  # if None, process sentences independently
            max_span_size=10,
            model_name=None):
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.freeze_pretrained = freeze_pretrained
        self.fc_hidden_size = fc_hidden_size
        self.max_span_size = max_span_size
        self.max_segment_size = max_segment_size
        self.learning_rate = learning_rate
        self.layer_learning_rate = layer_learning_rate if layer_learning_rate is not None else {}

        self.pretrained_embeddings_dir = pretrained_embeddings_dir
        self.embedder = Elmo(
            options_file=os.path.join(pretrained_embeddings_dir,
                                      "options.json"),
            weight_file=os.path.join(pretrained_embeddings_dir,
                                     "slovenian-elmo-weights.hdf5"),
            dropout=(0.0 if freeze_pretrained else dropout),
            num_output_representations=1,
            requires_grad=(not freeze_pretrained)).to(DEVICE)
        embedding_size = self.embedder.get_output_dim()

        self.context_encoder = nn.LSTM(input_size=embedding_size,
                                       hidden_size=hidden_size,
                                       batch_first=True,
                                       bidirectional=True).to(DEVICE)
        self.scorer = NeuralCoreferencePairScorer(num_features=(2 *
                                                                hidden_size),
                                                  hidden_size=fc_hidden_size,
                                                  dropout=dropout).to(DEVICE)
        params_to_update = [{
            "params":
            self.scorer.parameters(),
            "lr":
            self.layer_learning_rate.get("lr_scorer", self.learning_rate)
        }, {
            "params":
            self.context_encoder.parameters(),
            "lr":
            self.layer_learning_rate.get("lr_context_encoder",
                                         self.learning_rate)
        }]
        if not freeze_pretrained:
            params_to_update.append({
                "params":
                self.embedder.parameters(),
                "lr":
                self.layer_learning_rate.get("lr_embedder", self.learning_rate)
            })

        self.optimizer = optim.Adam(params_to_update, lr=self.learning_rate)

        super().__init__(learning_rate=learning_rate,
                         dataset_name=dataset_name,
                         model_name=model_name)
        logging.info(
            f"Initialized contextual ELMo-based model with name {self.model_name}."
        )

    @property
    def model_base_dir(self):
        return "contextual_model_elmo"

    def train_mode(self):
        if not self.freeze_pretrained:
            self.embedder.train()
        self.context_encoder.train()
        self.scorer.train()

    def eval_mode(self):
        self.embedder.eval()
        self.context_encoder.eval()
        self.scorer.eval()

    def load_checkpoint(self):
        self.loaded_from_file = True
        self.context_encoder.load_state_dict(
            torch.load(os.path.join(self.path_model_dir, "context_encoder.th"),
                       map_location=DEVICE))
        self.scorer.load_state_dict(
            torch.load(os.path.join(self.path_model_dir, "scorer.th"),
                       map_location=DEVICE))

        path_to_embeddings = os.path.join(self.path_model_dir, "embeddings.th")
        if os.path.isfile(path_to_embeddings):
            logging.info(
                f"Loading fine-tuned ELMo weights from '{path_to_embeddings}'")
            self.embedder.load_state_dict(
                torch.load(path_to_embeddings, map_location=DEVICE))

    @staticmethod
    def from_pretrained(model_dir):
        controller_config_path = os.path.join(model_dir,
                                              "controller_config.json")
        with open(controller_config_path, "r", encoding="utf-8") as f_config:
            pre_config = json.load(f_config)

        instance = ContextualControllerELMo(**pre_config)
        instance.load_checkpoint()

        return instance

    def save_pretrained(self, model_dir):
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        # Write controller config (used for instantiation)
        controller_config_path = os.path.join(model_dir,
                                              "controller_config.json")
        with open(controller_config_path, "w", encoding="utf-8") as f_config:
            json.dump(
                {
                    "hidden_size": self.hidden_size,
                    "dropout": self.dropout,
                    "pretrained_embeddings_dir":
                    self.pretrained_embeddings_dir,
                    "dataset_name": self.dataset_name,
                    "fc_hidden_size": self.fc_hidden_size,
                    "freeze_pretrained": self.freeze_pretrained,
                    "learning_rate": self.learning_rate,
                    "layer_learning_rate": self.layer_learning_rate,
                    "max_segment_size": self.max_segment_size,
                    "max_span_size": self.max_span_size,
                    "model_name": self.model_name
                },
                fp=f_config,
                indent=4)

        torch.save(self.context_encoder.state_dict(),
                   os.path.join(self.path_model_dir, "context_encoder.th"))
        torch.save(self.scorer.state_dict(),
                   os.path.join(self.path_model_dir, "scorer.th"))

        # Save fine-tuned ELMo embeddings only if they're not frozen
        if not self.freeze_pretrained:
            torch.save(self.embedder.state_dict(),
                       os.path.join(self.path_model_dir, "embeddings.th"))

    def save_checkpoint(self):
        logging.warning(
            "save_checkpoint() is deprecated. Use save_pretrained() instead")
        self.save_pretrained(self.path_model_dir)

    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled. """
        ret = {}

        # By default, each sentence is its own segment, meaning sentences are processed independently
        if self.max_segment_size is None:

            def get_position(t):
                return t.sentence_index, t.position_in_sentence

            _encoded_segments = batch_to_ids(curr_doc.raw_sentences())
        # Optionally, one can specify max_segment_size, in which case segments of tokens are processed independently
        else:

            def get_position(t):
                doc_position = t.position_in_document
                return doc_position // self.max_segment_size, doc_position % self.max_segment_size

            flattened_doc = list(chain(*curr_doc.raw_sentences()))
            num_segments = (len(flattened_doc) + self.max_segment_size -
                            1) // self.max_segment_size
            _encoded_segments = \
                batch_to_ids([flattened_doc[idx_seg * self.max_segment_size: (idx_seg + 1) * self.max_segment_size]
                              for idx_seg in range(num_segments)])

        encoded_segments = []
        # Convention: Add a PAD word ([0] * max_chars vector) at the end of each segment, for padding mentions
        for curr_sent in _encoded_segments:
            encoded_segments.append(
                torch.cat((curr_sent,
                           torch.zeros(
                               (1, ELMoCharacterMapper.max_word_length),
                               dtype=torch.long))))
        encoded_segments = torch.stack(encoded_segments)

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        all_candidate_data = []
        for idx_head, (head_id,
                       head_mention) in enumerate(curr_doc.mentions.items(),
                                                  1):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

            # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`)
            candidates, candidate_data = [None], []
            candidate_attention = []
            correct_antecedents = []

            curr_head_data = [[], []]
            num_head_words = 0
            for curr_token in head_mention.tokens:
                idx_segment, idx_inside_segment = get_position(curr_token)
                curr_head_data[0].append(idx_segment)
                curr_head_data[1].append(idx_inside_segment)
                num_head_words += 1

            if num_head_words > self.max_span_size:
                curr_head_data[0] = curr_head_data[0][:self.max_span_size]
                curr_head_data[1] = curr_head_data[1][:self.max_span_size]
            else:
                curr_head_data[0] += [curr_head_data[0][-1]
                                      ] * (self.max_span_size - num_head_words)
                curr_head_data[1] += [-1
                                      ] * (self.max_span_size - num_head_words)

            head_attention = torch.ones((1, self.max_span_size),
                                        dtype=torch.bool)
            head_attention[0, num_head_words:] = False

            for idx_candidate, (cand_id, cand_mention) in enumerate(
                    curr_doc.mentions.items(), start=1):
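                # Candidate antecedents are exactly the mentions that precede the current head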
                if idx_candidate >= idx_head:
                    break

                candidates.append(cand_id)

                # Maps tokens to positions inside segments (idx_seg, idx_inside_seg) for efficient indexing later
                curr_candidate_data = [[], []]
                num_candidate_words = 0
                for curr_token in cand_mention.tokens:
                    idx_segment, idx_inside_segment = get_position(curr_token)
                    curr_candidate_data[0].append(idx_segment)
                    curr_candidate_data[1].append(idx_inside_segment)
                    num_candidate_words += 1

                if num_candidate_words > self.max_span_size:
                    curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size]
                    curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size]
                else:
                    # padding tokens index into the PAD token of the last segment
                    curr_candidate_data[0] += [curr_candidate_data[0][-1]] * (self.max_span_size - num_candidate_words)
                    curr_candidate_data[1] += [-1] * (self.max_span_size - num_candidate_words)

                candidate_data.append(curr_candidate_data)
                curr_attention = torch.ones((1, self.max_span_size),
                                            dtype=torch.bool)
                curr_attention[0, num_candidate_words:] = False
                candidate_attention.append(curr_attention)

                is_coreferent = cand_id in gt_antecedent_ids
                if is_coreferent:
                    correct_antecedents.append(idx_candidate)

            if len(correct_antecedents) == 0:
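                # No preceding mention is coreferent: the gold "antecedent" is the dummy at index 0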
                correct_antecedents.append(0)

            candidate_attention = (torch.cat(candidate_attention)
                                   if len(candidate_attention) > 0 else [])
            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": candidate_attention,
                "correct_antecedents": correct_antecedents
            })

        ret["preprocessed_segments"] = encoded_segments
        ret["steps"] = all_candidate_data

        return ret

    def _train_doc(self, curr_doc, eval_mode=False):
        """ Trains/evaluates (if `eval_mode` is True) model on specific document.
            Returns predictions, loss and number of examples evaluated. """

        if len(curr_doc.mentions) == 0:
            return {}, (0.0, 0)

        if not hasattr(curr_doc, "_cache_elmo"):
            curr_doc._cache_elmo = self._prepare_doc(curr_doc)
        cache = curr_doc._cache_elmo  # type: Dict

        encoded_segments = cache["preprocessed_segments"]
        if self.freeze_pretrained:
            with torch.no_grad():
                res = self.embedder(encoded_segments.to(DEVICE))
        else:
            res = self.embedder(encoded_segments.to(DEVICE))

        # `res` is the Elmo forward() output: a dict with "elmo_representations" (a list) and "mask"
        # Note: max_segment_size is either specified at instantiation or (length of the longest sentence + 1)
        # embedded_segments: [num_segments, max_segment_size, embedding_size]
        embedded_segments = res["elmo_representations"][0]
        # lstm_segments: [num_segments, max_segment_size, 2 * hidden_size]
        (lstm_segments, _) = self.context_encoder(embedded_segments)

        doc_loss, n_examples = 0.0, len(cache["steps"])
        preds = {}

        for curr_step in cache["steps"]:
            head_id = curr_step["head_id"]
            head_data = curr_step["head_data"]

            candidates = curr_step["candidates"]
            candidate_data = curr_step["candidate_data"]
            correct_antecedents = curr_step["correct_antecedents"]

            # Note: num_candidates includes dummy antecedent + actual candidates
            num_candidates = len(candidates)
            if num_candidates == 1:
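                # Only the dummy candidate exists (first mention in the document) - predict "no antecedent"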
                curr_pred = 0
            else:
                idx_segment = candidate_data[:, 0, :]
                idx_in_segment = candidate_data[:, 1, :]

                # [num_candidates - 1, max_span_size, 2 * hidden_size] (no row for the dummy antecedent)
                candidate_data = lstm_segments[idx_segment, idx_in_segment]
                # [1, max_span_size, 2 * hidden_size]
                head_data = lstm_segments[head_data[:, 0, :], head_data[:, 1, :]]
                head_data = head_data.repeat((num_candidates - 1, 1, 1))

                candidate_scores = self.scorer(
                    candidate_data, head_data,
                    curr_step["candidate_attention"],
                    curr_step["head_attention"].repeat(
                        (num_candidates - 1, 1)))

                # Prepend a fixed score of 0.0 for the dummy (no-antecedent) option -> [1, num_candidates]
                candidate_scores = torch.cat((torch.tensor([0.0], device=DEVICE),
                                              candidate_scores.flatten())).unsqueeze(0)

                curr_pred = torch.argmax(candidate_scores)
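                # Score the same candidate row against every gold antecedent of this head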
                doc_loss += self.loss(
                    candidate_scores.repeat((len(correct_antecedents), 1)),
                    torch.tensor(correct_antecedents, device=DEVICE))

            # Accumulate { predicted antecedent id: [head mention id(s)] } pairs
            existing_refs = preds.get(candidates[int(curr_pred)], [])
            existing_refs.append(head_id)
            preds[candidates[int(curr_pred)]] = existing_refs

        if not eval_mode and torch.is_tensor(doc_loss):
            # doc_loss stays a plain 0.0 when no head had real candidates (e.g. a single-mention document)
            doc_loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        return preds, (float(doc_loss), n_examples)
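
    # Usage sketch (hypothetical setup: a controller instance `model` built on this class and a
    # preprocessed `Document` `doc` exposing `mentions`, `clusters` and `raw_sentences()`):
    #   preds, (loss, n) = model._train_doc(doc)                  # one optimizer step on `doc`
    #   preds, (loss, n) = model._train_doc(doc, eval_mode=True)  # scoring only, no backward/step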