def _pooling_avg(self, encoded_sentences, mask, len_seq_sent):
        mask = mask.unsqueeze(3).repeat(1, 1, 1, encoded_sentences.size(3))
        encoded_sentences = encoded_sentences * mask
        len_seq_sent = utils.cast_type(len_seq_sent, FLOAT, self.use_gpu)
        encoded_sentences = torch.div(torch.sum(encoded_sentences, dim=2), len_seq_sent.unsqueeze(2))

        return encoded_sentences, mask
def update_cp_seq(len_seq, len_sents, config, cur_cp_ind, text_inputs,
                  cp_seq_list):
    sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
    sent_mask = utils.cast_type(sent_mask, FLOAT, config.use_gpu)
    num_sents = sent_mask.sum(dim=1)  # (batch_size)
    for batch_ind, cur_doc_cp in enumerate(cur_cp_ind):
        cur_doc_len = int(len_seq[batch_ind])
        cur_sent_len = len_sents[batch_ind]
        cur_sent_num = int(num_sents[batch_ind])  # number of valid sentences in this doc
        end_loc = 0
        prev_loc = 0

        cur_cp_seq = []
        cur_text_seq = text_inputs[batch_ind]
        for cur_sent_ind in range(cur_sent_num):
            end_loc += int(cur_sent_len[cur_sent_ind])
            cur_sent_seq = cur_text_seq[
                prev_loc:end_loc]  # get the current sentence seq from a doc

            cur_sent_cp = cur_doc_cp[cur_sent_ind]
            cur_cp_seq.append(int(
                cur_sent_seq[cur_sent_cp]))  # get seq id of Cp

            prev_loc = end_loc
        # end for
        cp_seq_list.append(cur_cp_seq)

    # end for

    return cp_seq_list
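A minimal sketch of the indexing logic above, with hypothetical toy values: each per-sentence center index in cur_cp_ind is offset by the cumulative length of the preceding sentences to recover the token id of the preferred center (Cp) from the flat document sequence.

import torch

# hypothetical: one document with three sentences of lengths 4, 3, 2
len_sents = torch.tensor([4, 3, 2])
text_inputs = torch.tensor([11, 12, 13, 14, 21, 22, 23, 31, 32])  # flat token ids
cp_ind = torch.tensor([2, 0, 1])  # within-sentence index of each preferred center

offsets = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(len_sents, dim=0)[:-1]])
cp_token_ids = text_inputs[offsets + cp_ind]  # tensor([13, 21, 32])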
    def _pooling_doc_max(self, encoded_docs, mask, target_dim):
        mask = utils.cast_type(mask, FLOAT, self.use_gpu)

        mask = ((mask - 1) * 999).unsqueeze(target_dim + 1).repeat(1, 1, encoded_docs.size(target_dim + 1))
        encoded_docs = encoded_docs + mask
        encoded_docs = encoded_docs.max(dim=target_dim)[0]  # Batch * sent * dim

        return encoded_docs, mask
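The masked max above works by adding a large negative offset to padded positions so they can never win the max. A minimal standalone sketch of the same trick, with made-up shapes:

import torch

encoded = torch.randn(2, 3, 4)       # hypothetical (batch, num_sents, dim)
mask = torch.tensor([[1., 1., 0.],   # doc 0 has 2 valid sentences
                     [1., 1., 1.]])  # doc 1 has 3

offset = ((mask - 1) * 999).unsqueeze(2).expand_as(encoded)  # 0 for valid, -999 for padded
doc_repr = (encoded + offset).max(dim=1)[0]                  # (2, 4); padding is never selected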
Example #4
    def eval_measure_mse(self):
        #
        cur_pred_list = torch.cat(self.pred_list, dim=0)
        reverted_pred = self._convert_to_origin_scale(cur_pred_list)
        reverted_pred = torch.round(reverted_pred)
        predicted = utils.cast_type(reverted_pred, LONG, self.use_gpu)

        cur_label_list = [label for row in self.label_list for label in row]
        cur_label_list = torch.LongTensor(cur_label_list)
        cur_label_list = utils.cast_type(cur_label_list, LONG, self.use_gpu)
        cur_label_list = cur_label_list.view(cur_label_list.shape[0], 1)

        correct = (predicted == cur_label_list).sum().item()
        total = predicted.size(0)
        accuracy = correct / total

        self.eval_reset()

        return accuracy
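The evaluation above reduces to rounding rescaled regression outputs and comparing them with the integer labels. A small self-contained sketch, replacing _convert_to_origin_scale with hypothetical predictions that are already on the original scale:

import torch

preds = torch.tensor([[7.6], [3.2], [9.9]])  # hypothetical predictions on the original score scale
labels = torch.tensor([[8], [3], [9]])       # gold integer scores

predicted = torch.round(preds).long()
accuracy = (predicted == labels).sum().item() / predicted.size(0)  # 2/3 here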
Example #5
 def __init__(self, config, vocab=None, key_vocab=None, ignore_vocab=None):
     super(PPLLoss, self).__init__()
     self.weight = None
     self.ignore = None
     if vocab is not None:
         if key_vocab is not None:
             self.logger.info("Use extra cost for key words")
             weight = np.ones(len(vocab))
             for key_w in key_vocab.values():
                 weight[vocab.token2id.get(key_w)] = 10.0
             self.weight = cast_type(torch.from_numpy(weight), FLOAT,
                                     config.use_gpu)
         if ignore_vocab is not None:
             self.logger.info("Use extra vocab for ignore words")
             ignore = np.ones(len(vocab))
             for ignore_w in ignore_vocab.values():
                 ignore[vocab.token2id.get(ignore_w)] = 0.0
             self.ignore = cast_type(torch.from_numpy(ignore), FLOAT,
                                     config.use_gpu)
Example #6
    def __init__(self, padding_idx, config, rev_vocab=None, key_vocab=None):
        super(NLLEntropy, self).__init__()
        self.padding_idx = padding_idx if padding_idx is not None else -100
        self.avg_type = config.avg_type

        if rev_vocab is None or key_vocab is None:
            self.weight = None
        else:
            self.logger.info("Use extra cost for key words")
            weight = np.ones(len(rev_vocab))
            for key in key_vocab:
                weight[rev_vocab[key]] = 10.0
            self.weight = cast_type(torch.from_numpy(weight), FLOAT,
                                    config.use_gpu)
Example #7
    def gumbel_max(self, log_probs):
        """
        Obtain a sample via the Gumbel-max trick. Note this is not differentiable.
        :param log_probs: [batch_size x vocab_size]
        :return: [batch_size x 1] selected token IDs
        """
        sample = torch.Tensor(log_probs.size()).uniform_(0, 1)
        sample = cast_type(Variable(sample), FLOAT, self.use_gpu)

        # compute the gumbel sample
        matrix_u = -1.0 * torch.log(-1.0 * torch.log(sample))
        gumbel_log_probs = log_probs + matrix_u
        max_val, max_ids = torch.max(gumbel_log_probs, dim=-1, keepdim=True)
        return max_ids
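A self-contained sketch of the Gumbel-max trick used above (plain tensors, no Variable or cast_type): adding Gumbel(0, 1) noise to the log-probabilities and taking the argmax is equivalent to sampling from the categorical distribution.

import torch

def gumbel_max_sample(log_probs, eps=1e-20):
    # eps guards against log(0); the snippet above omits it
    u = torch.rand_like(log_probs)
    gumbel = -torch.log(-torch.log(u + eps) + eps)
    return torch.argmax(log_probs + gumbel, dim=-1, keepdim=True)

log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # hypothetical [batch x vocab]
sampled_ids = gumbel_max_sample(log_probs)                 # (4, 1) selected token ids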
Example #8
 def forward(self, s=None, a_seq=None, mode=None, gen_type=None):
     a_seq = cast_type(a_seq, LONG, USE_GPU)
     batch_size = s.shape[0]
     # logging.info("s shape: {}".format(s.shape))
     dec_init_state = self.net(s).unsqueeze(0)
     # logging.info("h: {}, inp: {}".format(dec_init_state.shape, a_seq.shape))
     dec_outs, dec_last, dec_ctx = self.decoder(batch_size,
                                                a_seq[:, 0:-1],
                                                dec_init_state,
                                                mode=mode,
                                                gen_type=gen_type,
                                                beam_size=1)
     labels = a_seq[:, 1:].contiguous()
     enc_dec_nll = self.nll_loss(dec_outs, labels)
     return enc_dec_nll
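The decoder above is trained with teacher forcing: the action sequence shifted right is fed as input and the same sequence shifted left is the target. A toy illustration of the shift:

import torch

a_seq = torch.tensor([[2, 5, 7, 3]])  # hypothetical ids: <s> a b </s>
dec_inputs = a_seq[:, 0:-1]           # [[2, 5, 7]]  fed to the decoder
labels = a_seq[:, 1:]                 # [[5, 7, 3]]  targets for the NLL loss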
Example #9
    def gen_positional_input(self, seq_x):
        pos_x = []
        for cur_batch in seq_x:
            cur_pos = []
            for ind, val in enumerate(cur_batch):
                if val != 0:
                    cur_pos.append(ind + 1)
                else:
                    cur_pos.append(0)
            pos_x.append(cur_pos)

        pos_input = torch.LongTensor(pos_x)
        pos_input = utils.cast_type(pos_input, LONG, self.use_gpu)

        return pos_input
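The nested loops above can be collapsed into a single tensor expression; a hedged equivalent, assuming 0 is the padding id:

import torch

seq_x = torch.tensor([[5, 9, 2, 0, 0],
                      [7, 0, 0, 0, 0]])  # hypothetical token ids, 0 = padding

positions = torch.arange(1, seq_x.size(1) + 1).expand_as(seq_x)
pos_input = positions * (seq_x != 0).long()
# tensor([[1, 2, 3, 0, 0],
#         [1, 0, 0, 0, 0]])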
    def centering_attn(self, text_inputs, mask_input, len_sents, num_sents,
                       tid):

        ## Parser stage1: determine forward-looking centers and preferred centers
        avg_sents_repr, fwrd_repr, batch_cp_ind = self.get_fwrd_centers(
            text_inputs, mask_input, len_sents)

        ## Parser stage2: decide backward center
        back_repr = self.get_back_centers(avg_sents_repr, fwrd_repr)

        ## Parser stage3: construct hierarchical discourse segments
        batch_segMap = []
        batch_adj_mat = []
        batch_adj_list = []
        batch_root_list = []
        for ind_batch, cur_batch_repr in enumerate(back_repr):
            cur_sent_num = int(num_sents[ind_batch])

            ## Parser stage3-1: get structural information
            seg_map, adj_list, list_root_ds = self.get_disco_seg(
                cur_sent_num, ind_batch, fwrd_repr, cur_batch_repr)

            ## Parser stage3-2: make a tree structure using the information
            cur_tree = self.make_tree_stru(seg_map, adj_list, list_root_ds)

            ## Parser stage3-3: make a numpy array from networkx tree
            cur_adj_mat = np.zeros((self.max_num_sents, self.max_num_sents))
            undir_tree = cur_tree.to_undirected()  # we make an undirected tree
            np_adj_mat = nx.to_numpy_matrix(undir_tree)

            cur_adj_mat[:np_adj_mat.shape[0], :np_adj_mat.shape[1]] = np_adj_mat

            ## store structures for statistical analysis
            batch_adj_mat.append(cur_adj_mat)
            batch_adj_list.append(adj_list)
            batch_root_list.append(list_root_ds)
            batch_segMap.append(list(seg_map.items()))

        # end for ind_batch

        # structural information which will be passed to structure-aware transformer
        adj_mat = torch.from_numpy(np.array(batch_adj_mat))
        adj_mat = utils.cast_type(adj_mat, FLOAT, self.use_gpu)
        batch_cp_ind = batch_cp_ind.tolist()

        return adj_mat, avg_sents_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind
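Stage 3-3 above pads the tree's adjacency into a fixed max_num_sents square matrix. A minimal sketch with a hypothetical 3-node tree and a fixed size of 5, using nx.to_numpy_array (the non-deprecated counterpart of the to_numpy_matrix call in the snippet):

import networkx as nx
import numpy as np

max_num_sents = 5                    # hypothetical fixed size
tree = nx.DiGraph([(0, 1), (0, 2)])  # hypothetical discourse tree with root 0

np_adj = nx.to_numpy_array(tree.to_undirected())
adj_mat = np.zeros((max_num_sents, max_num_sents))
adj_mat[:np_adj.shape[0], :np_adj.shape[1]] = np_adj  # padded (5, 5) adjacency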
Example #11
    def sent_repr_avg(self, batch_size, encoder_out, len_sents):
        sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
        num_sents = sent_mask.sum(dim=1)  # (batch_size)

        sent_repr = torch.zeros(batch_size, self.max_num_sents,
                                self.encoder_coh.encoder_out_size)
        sent_repr = utils.cast_type(sent_repr, FLOAT, self.use_gpu)
        for cur_ind_doc in range(batch_size):
            list_sent_len = len_sents[cur_ind_doc]
            cur_sent_num = int(num_sents[cur_ind_doc])
            cur_loc_sent = 0
            list_cur_doc_sents = []

            for cur_ind_sent in range(cur_sent_num):
                cur_sent_len = int(list_sent_len[cur_ind_sent])
                # cur_local_words = local_output_words[cur_batch, cur_ind_sent:end_sent, :]

                # cur_sent_repr = encoder_out[cur_ind_doc, cur_loc_sent+cur_sent_len-1, :]  # pick the last representation of each sentence
                cur_sent_repr = torch.div(
                    torch.sum(
                        encoder_out[cur_ind_doc,
                                    cur_loc_sent:cur_loc_sent + cur_sent_len],
                        dim=0), cur_sent_len)  # avg version
                cur_sent_repr = cur_sent_repr.view(
                    1, 1, -1)  # restore to (1, 1, xrnn_cell_size)

                list_cur_doc_sents.append(cur_sent_repr)
                cur_loc_sent = cur_loc_sent + cur_sent_len

            # end for cur_len_sent

            cur_sents_repr = torch.stack(
                list_cur_doc_sents,
                dim=1)  # (batch_size, num_sents, rnn_cell_size)
            cur_sents_repr = cur_sents_repr.squeeze(
                2)  # not needed when the last repr is used

            sent_repr[cur_ind_doc, :cur_sent_num, :] = cur_sents_repr
        # end for cur_doc

        return sent_repr
    def sent_repr_avg(self, batch_size, encoder_out, len_sents):
        """return sentence representation by averaging of all words."""

        mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
        num_sents = mask_sent.sum(dim=1)  # (batch_size)

        sent_repr = torch.zeros(batch_size, self.max_num_sents,
                                self.base_encoder.encoder_out_size)
        sent_repr = utils.cast_type(sent_repr, FLOAT, self.use_gpu)
        for cur_ind_doc in range(batch_size):
            list_sent_len = len_sents[cur_ind_doc]
            cur_sent_num = int(num_sents[cur_ind_doc])
            cur_loc_sent = 0
            list_cur_doc_sents = []

            for cur_ind_sent in range(cur_sent_num):
                cur_sent_len = int(list_sent_len[cur_ind_sent])

                cur_sent_repr = torch.div(
                    torch.sum(
                        encoder_out[cur_ind_doc,
                                    cur_loc_sent:cur_loc_sent + cur_sent_len],
                        dim=0), cur_sent_len)  # avg version
                cur_sent_repr = cur_sent_repr.view(
                    1, 1, -1)  # restore to (1, 1, xrnn_cell_size)

                list_cur_doc_sents.append(cur_sent_repr)
                cur_loc_sent = cur_loc_sent + cur_sent_len

            # end for cur_len_sent

            cur_sents_repr = torch.stack(
                list_cur_doc_sents,
                dim=1)  # (batch_size, num_sents, rnn_cell_size)
            cur_sents_repr = cur_sents_repr.squeeze(
                2)  # not needed when the last repr is used

            sent_repr[cur_ind_doc, :cur_sent_num, :] = cur_sents_repr
        # end for cur_doc

        return sent_repr
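The per-sentence averaging above can also be written with torch.split over the flat token axis; a minimal sketch for one hypothetical document:

import torch

encoder_out = torch.randn(9, 16)  # hypothetical: 9 tokens, 16-dim encoder outputs for one document
len_sents = [4, 3, 2]             # hypothetical sentence lengths

sent_repr = torch.stack([chunk.mean(dim=0)
                         for chunk in torch.split(encoder_out, len_sents, dim=0)])
# sent_repr: (3, 16), one averaged vector per sentence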
Example #13
 def forward(self,
             logits,
             temperature,
             use_gpu,
             hard=False,
             return_max_id=False):
     """
     :param logits: [batch_size, n_class] unnormalized log-prob
     :param temperature: non-negative scalar
     :param hard: if True take argmax
     :return: [batch_size, n_class] sample from gumbel softmax
     """
     y = self.gumbel_softmax_sample(logits, temperature, use_gpu)
     _, y_hard = torch.max(y, dim=-1, keepdim=True)
     if hard:
         y_onehot = cast_type(torch.zeros(y.size()), FLOAT, use_gpu)
         y_onehot.scatter_(-1, y_hard, 1.0)
         y = y_onehot
     if return_max_id:
         return y, y_hard
     else:
         return y
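A self-contained sketch mirroring the hard branch above: draw a Gumbel-softmax sample, then snap it to a one-hot vector of its argmax. This sketch inlines the Gumbel noise that gumbel_softmax_sample presumably adds, and, like the snippet, does not apply a straight-through gradient trick.

import torch
import torch.nn.functional as F

def gumbel_softmax_hard(logits, temperature=1.0, eps=1e-20):
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)
    y = F.softmax((logits + gumbel) / temperature, dim=-1)
    _, y_hard = torch.max(y, dim=-1, keepdim=True)
    y_onehot = torch.zeros_like(y).scatter_(-1, y_hard, 1.0)
    return y_onehot, y_hard

logits = torch.randn(4, 6)                # hypothetical [batch_size, n_class] unnormalized log-probs
y, max_ids = gumbel_softmax_hard(logits)  # one-hot samples and their class ids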
Example #14
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                mode=""):

        # #
        if self.pad_level == "sent" or self.pad_level == "sentence":
            # text_inputs = text_inputs.view(self.batch_size, self.max_num_sents*self.max_len_sent)
            text_inputs = text_inputs.view(
                text_inputs.size(0),
                text_inputs.size(1) * text_inputs.size(2))

        #
        encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

        ## averaging RNN output by their length; named implicit lexical cohesion vector
        len_seq = utils.cast_type(len_seq, FLOAT, self.use_gpu)
        ilc_vec = torch.div(
            torch.sum(encoder_out, dim=1),
            len_seq.unsqueeze(1))  # (batch_size, rnn_cell_size)

        #### Fully Connected
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)

        if self.corpus_target.lower() == "asap":
            fc_out = self.sigmoid(fc_out)

        return fc_out
    def forward(self, text_inputs, mask_input, len_seq, len_sents, tid, len_para=None, list_rels=None, mode=""):

        #### word level representations
        encoder_out = self.encoder_base(text_inputs, mask_input, len_seq)

        #### make sentence representations
        batch_size = text_inputs.size(0)

        sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
        sent_mask = utils.cast_type(sent_mask, FLOAT, self.use_gpu)
        num_sents = sent_mask.sum(dim=1)  # (batch_size)

        ### sentence level representations
        sent_repr = self.sent_repr_avg(batch_size, encoder_out, len_sents)

        # do not consider sentence level attention in this implementation
        encoded_sentences = self.encoder_sent.forward_skip(sent_repr, sent_mask, num_sents)

        # structured attention for document level
        encoded_documents, doc_attention_matrix = self.structure_att(encoded_sentences)  # get structured attn for sent

        # pooling for sentence
        if self.pooling_sent.lower() == "avg":
            encoded_documents, mask = self._pooling_avg(encoded_documents, sent_mask, num_sents)
        elif self.pooling_sent.lower() == "max":
            encoded_documents, mask = self._pooling_doc_max(encoded_documents, sent_mask, 1)

        ## Fully Connected layers
        fc_out = self.linear_out(encoded_documents)

        if self.corpus_target.lower() == "asap":
            fc_out = self.sigmoid(fc_out)

        model_outputs = []
        model_outputs.append(fc_out)

        # return fc_out
        return model_outputs
    def get_back_centers(self, avg_sents_repr, fwrd_repr):
        """ Determine backward-looking centers"""
        batch_size = avg_sents_repr.size(0)
        back_repr = torch.zeros(batch_size, self.max_num_sents, self.topk_back,
                                self.base_encoder.encoder_out_size)
        back_repr = utils.cast_type(back_repr, FLOAT, self.use_gpu)

        for sent_i in range(self.max_num_sents):
            if sent_i == 0 or sent_i == self.max_num_sents - 1:
                # skip the first sentence (no backward-looking center) and the last index
                continue
            # end if
            else:
                prev_fwrd_repr = fwrd_repr[:, sent_i -
                                           1, :, :]  # (batch_size, topk_fwrd, dim)
                cur_fwrd_repr = fwrd_repr[:,
                                          sent_i, :, :]  # (batch_size, topk_fwrd, dim)
                cur_sent_repr = avg_sents_repr[:,
                                               sent_i, :]  # (batch_size, dim)

                sim_rank = self.sim_cosine_d2(prev_fwrd_repr,
                                              cur_sent_repr.unsqueeze(1))

                max_sim_val, max_sim_ind = torch.max(sim_rank, dim=1)

                idx = max_sim_ind.view(-1, self.topk_back, 1).expand(
                    max_sim_ind.size(0), self.topk_back,
                    self.base_encoder.encoder_out_size)
                cur_back_repr = prev_fwrd_repr.gather(1, idx)

                back_repr[:, sent_i] = cur_back_repr
                # end for topk_i
            # end else
        # end for sent_i

        return back_repr
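The selection step above gathers, per document, the previous sentence's forward-looking center most similar to the current sentence representation. A toy sketch of the gather pattern, using F.cosine_similarity as a stand-in for the model's sim_cosine_d2:

import torch
import torch.nn.functional as F

prev_fwrd = torch.randn(2, 3, 8)  # hypothetical (batch, topk_fwr, dim) forward centers
cur_sent = torch.randn(2, 8)      # (batch, dim) current sentence representations

sim = F.cosine_similarity(prev_fwrd, cur_sent.unsqueeze(1).expand_as(prev_fwrd), dim=2)  # (2, 3)
max_ind = sim.argmax(dim=1)
idx = max_ind.view(-1, 1, 1).expand(2, 1, 8)
back_center = prev_fwrd.gather(1, idx)  # (2, 1, 8): most similar previous center per document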
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                mode=""):
        #
        if self.pad_level == "sent" or self.pad_level == "sentence":
            text_inputs = text_inputs.view(
                text_inputs.shape[0], self.max_num_sents * self.max_len_sent)
        mask = mask_input.view(text_inputs.shape)

        #
        encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

        # applying conv1d after rnn
        avg_pooled = torch.zeros(text_inputs.shape[0], text_inputs.shape[1],
                                 self.conv_output_size)
        avg_pooled = utils.cast_type(avg_pooled, FLOAT, self.use_gpu)
        for cur_batch, cur_tensor in enumerate(encoder_out):
            ## Actual length version
            if self.target_model == "conll17_al":
                cur_seq_len = int(len_seq[cur_batch])
                cur_tensor = cur_tensor.unsqueeze(0)
                crop_tensor = cur_tensor.narrow(1, 0, cur_seq_len)
                crop_tensor = crop_tensor.transpose(1, 0)
                cur_tensor = crop_tensor
            ## published version: do not consider actual length
            else:
                cur_tensor = cur_tensor.unsqueeze(1)

            # applying conv
            cur_tensor = self.conv(cur_tensor)
            cur_tensor = self.leak_relu(cur_tensor)
            cur_tensor = self.dropout_layer(cur_tensor)
            # cur_tensor = self.avg_pool_1d(cur_tensor)
            cur_tensor = self.avg_adapt_pool1(cur_tensor)
            cur_tensor = cur_tensor.view(cur_tensor.shape[0],
                                         self.conv_output_size)
            avg_pooled[cur_batch, :cur_tensor.shape[0], :] = cur_tensor

        len_seq = utils.cast_type(len_seq, FLOAT, self.use_gpu)

        ## implement attention by parameters
        context_weight = self.context_weight.unsqueeze(1)
        context_weight = context_weight.expand(text_inputs.shape[0],
                                               self.conv_output_size, 1)
        attn_weight = torch.bmm(avg_pooled, context_weight).squeeze(2)
        attn_weight = self.tanh(attn_weight)
        attn_weight = self.softmax(attn_weight)
        # attention applied
        attn_vec = torch.bmm(avg_pooled.transpose(1, 2),
                             attn_weight.unsqueeze(2))

        ilc_vec = attn_vec.squeeze(2)

        ## implement attention by linear
        #attn_vec = self.attn(encoder_out.view(self.batch_size, -1)).unsqueeze(2)
        #attn_vec = self.softmax(attn_vec)
        #ilc_vec_attn = torch.bmm(encoder_out.transpose(1, 2), attn_vec).squeeze(2)

        ## FC

        # fully connected stage
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)
        if self.corpus_target.lower() == "asap":
            fc_out = self.sigmoid(fc_out)

        return fc_out
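The attention-by-parameter block above scores every position against a single learned context vector, normalizes the scores, and takes the weighted sum. A minimal sketch with a random stand-in for the learned context_weight:

import torch

pooled = torch.randn(2, 7, 16)  # hypothetical (batch, positions, dim) features
context = torch.randn(16)       # stand-in for the learned context_weight parameter

scores = torch.bmm(pooled, context.view(1, 16, 1).expand(2, 16, 1)).squeeze(2)  # (2, 7)
weights = torch.softmax(torch.tanh(scores), dim=1)
doc_vec = torch.bmm(pooled.transpose(1, 2), weights.unsqueeze(2)).squeeze(2)    # (2, 16)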
Example #18
def validate(model, evaluator, dataset_test, config, loss_func, is_test=False):
    model.eval()
    losses = []

    sampler_test = SequentialSampler(dataset_test) if config.local_rank == -1 else DistributedSampler(dataset_test)
    dataloader_test = DataLoader(dataset_test, sampler=sampler_test, batch_size=config.batch_size)

    adj_list = []
    root_ds_list = []
    seg_map_list = []
    cp_ind_list = []
    num_sents_list = []
    cp_seq_list = []

    tid_list = []
    label_list = []
    for text_inputs, label_y, *remains in dataloader_test:
        mask_input = remains[0]
        len_seq = remains[1]
        len_sents = remains[2]
        tid = remains[3]
        cur_origin_score = remains[-1]  # may not be the original score when the dataset does not need one; ignored in that case

        text_inputs = utils.cast_type(text_inputs, LONG, config.use_gpu)
        mask_input = utils.cast_type(mask_input, FLOAT, config.use_gpu)
        len_seq = utils.cast_type(len_seq, FLOAT, config.use_gpu)

        with torch.no_grad():
            model_outputs = model(text_inputs=text_inputs, mask_input=mask_input, len_seq=len_seq, len_sents=len_sents, tid=tid, mode="") 
            coh_score = model_outputs[0]

            if config.output_size == 1:
                coh_score = coh_score.view(text_inputs.shape[0])
            else:
                coh_score = coh_score.view(text_inputs.shape[0], -1)

            if config.output_size == 1:
                label_y = utils.cast_type(label_y, FLOAT, config.use_gpu)
            else:
                label_y = utils.cast_type(label_y, LONG, config.use_gpu)
            label_y = label_y.view(text_inputs.shape[0])

            if loss_func is not None:
                loss = loss_func(coh_score, label_y)

                losses.append(loss.item())

            evaluator.eval_update(coh_score, label_y, tid, cur_origin_score)

            # for the project of centering transformer
            if config.gen_logs and config.target_model.lower() == "cent_attn":
                batch_adj_list = model_outputs[1]
                batch_root_ds = model_outputs[2]
                batch_seg_map = model_outputs[3]
                batch_cp_ind = model_outputs[4]
                batch_num_sents = model_outputs[5]

                adj_list = adj_list + batch_adj_list
                root_ds_list = root_ds_list + batch_root_ds
                seg_map_list = seg_map_list + batch_seg_map
                cp_ind_list = cp_ind_list + batch_cp_ind
                num_sents_list = num_sents_list + batch_num_sents

                tid_list = tid_list + tid.flatten().tolist()
                label_list = label_list + label_y.flatten().tolist()

                cp_seq_list = update_cp_seq(len_seq, len_sents, config, model_outputs[4], text_inputs, cp_seq_list)

        # end with torch.no_grad()
    # end for batch_num

    eval_measure = evaluator.eval_measure(is_test)
    eval_best_val = None
    if is_test:
        eval_best_val = max(evaluator.eval_history)

    if loss_func is not None:
        valid_loss = sum(losses) / len(losses)
        if is_test:
            logger.info("Total valid loss {}".format(valid_loss))
    else:
        valid_loss = np.inf
    
    if is_test:
        logger.info("{} on Test {}".format(evaluator.eval_type, eval_measure))
        logger.info("Best {} on Test {}".format(evaluator.eval_type, eval_best_val))
    else:
        logger.info("{} on Valid {}".format(evaluator.eval_type, eval_measure))


    valid_itpt = [tid_list, label_list, adj_list, root_ds_list, seg_map_list, cp_seq_list]

    return valid_loss, eval_measure, eval_best_val, valid_itpt
    def get_fwrd_centers(self, text_inputs, mask_input, len_sents):
        """ Determine fowrard-looking centers using an attention matrix in a PLM """
        batch_size = text_inputs.size(0)

        fwrd_repr = torch.zeros(batch_size, self.max_num_sents, self.topk_fwr,
                                self.base_encoder.encoder_out_size)
        fwrd_repr = utils.cast_type(fwrd_repr, FLOAT, self.use_gpu)

        avg_sents_repr = torch.zeros(
            batch_size, self.max_num_sents, self.base_encoder.encoder_out_size
        )  # averaged sents repr in the sent level encoding
        avg_sents_repr = utils.cast_type(avg_sents_repr, FLOAT, self.use_gpu)

        batch_cp_ind = torch.zeros(
            batch_size,
            self.max_num_sents)  # only used for manual analysis later
        batch_cp_ind = utils.cast_type(batch_cp_ind, LONG, self.use_gpu)

        cur_ind = torch.zeros(batch_size, dtype=torch.int64)
        cur_ind = utils.cast_type(cur_ind, LONG, self.use_gpu)
        len_sents = utils.cast_type(len_sents, LONG, self.use_gpu)

        for sent_i in range(self.max_num_sents):
            cur_sent_lens = len_sents[:, sent_i]
            cur_max_len = int(torch.max(cur_sent_lens))

            if cur_max_len > 0:
                cur_sent_ids = torch.zeros(batch_size,
                                           cur_max_len,
                                           dtype=torch.int64)
                cur_sent_ids = utils.cast_type(cur_sent_ids, LONG,
                                               self.use_gpu)
                cur_mask = torch.zeros(batch_size,
                                       cur_max_len,
                                       dtype=torch.int64)
                cur_mask = utils.cast_type(cur_mask, FLOAT, self.use_gpu)

                prev_ind = cur_ind
                cur_ind = cur_ind + cur_sent_lens

                for batch_ind, sent_len in enumerate(cur_sent_lens):
                    cur_loc = cur_ind[batch_ind]
                    prev_loc = prev_ind[batch_ind]
                    cur_sent_ids[batch_ind, :cur_loc -
                                 prev_loc] = text_inputs[batch_ind,
                                                         prev_loc:cur_loc]
                    cur_mask[batch_ind, :cur_loc -
                             prev_loc] = mask_input[batch_ind,
                                                    prev_loc:cur_loc]

                # encode each sentence
                cur_encoded = self.base_encoder(cur_sent_ids, cur_mask,
                                                cur_sent_lens)

                encoded_sent = cur_encoded[
                    0]  # encoded output for the current sent
                attn_sent = cur_encoded[
                    1]  # averaged attention for the current sent (batch, item, item)

                ## filter out: we do not consider special tokens and punctuation as a center; <.>, <sep>, and <cls>
                list_diag = []
                for batch_ind, cur_mat in enumerate(
                        attn_sent):  # cur_mat:(item, item)
                    cur_diag = torch.diag(cur_mat, diagonal=0)

                    ## masking as the length of each sentence
                    cur_batch_sent_len = int(
                        cur_sent_lens[batch_ind])  # i th sentence with batch
                    if cur_batch_sent_len > 3:
                        cur_diag[cur_batch_sent_len - 3:] = 0  # also remove punctuation; XLNet puts SEP and CLS at the end of the sentence
                    else:
                        cur_diag[cur_batch_sent_len - 2:] = 0  # only remove the special tokens
                    list_diag.append(cur_diag)
                # end for
                attn_diag = torch.stack(list_diag)  # torch.diag does not support batched input

                ## select forward-looking centers by indices
                temp_fwr_centers, fwr_sort_ind = torch.sort(
                    attn_diag, dim=1,
                    descending=True)  # forward centers are selected by attn
                fwr_sort_ind = fwr_sort_ind[:, :self.topk_fwr]
                batch_cp_ind[:,
                             sent_i] = fwr_sort_ind[:,
                                                    0]  # only consider the top-1 item for Cp

                # handle the exceptional case where the sentence is shorter than topk
                fwr_centers = torch.zeros(batch_size, self.topk_fwr)
                fwr_centers = utils.cast_type(fwr_centers, LONG, self.use_gpu)
                fwr_centers[:, :fwr_sort_ind.size(1)] = fwr_sort_ind

                selected = encoded_sent.gather(
                    1,
                    fwr_centers.unsqueeze(-1).expand(
                        batch_size, self.topk_fwr,
                        self.base_encoder.encoder_out_size))
                fwrd_repr[:, sent_i, :fwr_centers.size(1)] = selected

                ## make a sentence representation by averaging
                cur_sent_lens = cur_sent_lens + 1e-9  # prevent zero division
                cur_avg_repr = torch.div(torch.sum(encoded_sent, dim=1),
                                         cur_sent_lens.unsqueeze(1))

                avg_sents_repr[:, sent_i] = cur_avg_repr
            # end if
        # end for sent_i

        return avg_sents_repr, fwrd_repr, batch_cp_ind
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                len_para=None,
                list_rels=None,
                mode=""):

        batch_size = text_inputs.size(0)

        #### stage1: sentence level representations
        sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
        sent_mask = utils.cast_type(sent_mask, FLOAT, self.use_gpu)
        num_sents = sent_mask.sum(dim=1)  # (batch_size)

        avg_sents_repr = torch.zeros(
            batch_size, self.max_num_sents, self.base_encoder.encoder_out_size
        )  # averaged sents repr in the sent level encoding
        avg_sents_repr = utils.cast_type(avg_sents_repr, FLOAT, self.use_gpu)

        cur_ind = torch.zeros(batch_size, dtype=torch.int64)
        cur_ind = utils.cast_type(cur_ind, LONG, self.use_gpu)
        len_sents = utils.cast_type(len_sents, LONG, self.use_gpu)
        for sent_i in range(self.max_num_sents):
            cur_sent_lens = len_sents[:, sent_i]
            cur_max_len = int(torch.max(cur_sent_lens))

            if cur_max_len > 0:
                cur_sent_ids = torch.zeros(batch_size,
                                           cur_max_len,
                                           dtype=torch.int64)
                cur_sent_ids = utils.cast_type(cur_sent_ids, LONG,
                                               self.use_gpu)
                cur_mask = torch.zeros(batch_size,
                                       cur_max_len,
                                       dtype=torch.int64)
                cur_mask = utils.cast_type(cur_mask, FLOAT, self.use_gpu)

                prev_ind = cur_ind
                cur_ind = cur_ind + cur_sent_lens

                for batch_ind, sent_len in enumerate(cur_sent_lens):
                    cur_loc = cur_ind[batch_ind]
                    prev_loc = prev_ind[batch_ind]
                    cur_sent_ids[batch_ind, :cur_loc -
                                 prev_loc] = text_inputs[batch_ind,
                                                         prev_loc:cur_loc]
                    cur_mask[batch_ind, :cur_loc -
                             prev_loc] = mask_input[batch_ind,
                                                    prev_loc:cur_loc]

            cur_encoded = self.base_encoder(cur_sent_ids, cur_mask,
                                            cur_sent_lens)

            encoded_sent = cur_encoded[
                0]  # encoded output for the current sent

            cur_sent_lens = cur_sent_lens + 1e-9  # prevent zero division
            cur_avg_repr = torch.div(torch.sum(encoded_sent, dim=1),
                                     cur_sent_lens.unsqueeze(1))

            avg_sents_repr[:, sent_i] = cur_avg_repr

        # encoder sentence
        encoded_sents = avg_sents_repr
        mask_sent = torch.arange(
            self.max_num_sents, device=num_sents.device).expand(
                len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
        mask_sent = utils.cast_type(mask_sent, BOOL, self.use_gpu)
        num_sents = utils.cast_type(num_sents, FLOAT, self.use_gpu)

        #### stage2: update sentence representations using the tree transformer
        encoded_sents, break_probs = self.tt_encoder(
            encoded_sents, mask_sent
        )  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

        #### stage3: document attention
        context_weight = self.context_weight.expand(encoded_sents.shape[0],
                                                    encoded_sents.shape[2], 1)
        attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
        attn_weight = self.tanh(attn_weight)
        attn_weight = masked_softmax(attn_weight, sent_mask)
        # attention applied
        attn_vec = torch.bmm(encoded_sents.transpose(1, 2),
                             attn_weight.unsqueeze(2))
        ilc_vec = attn_vec.squeeze(2)

        #### FC layer

        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)

        if self.output_size == 1:
            fc_out = self.sigmoid(fc_out)

        outputs = []
        outputs.append(fc_out)

        # ### Sentence Ordering Task : sum_score (MSELoss)
        # ### avg_sents_repr : [batch_size, sent_num]
        # order_label_list = []
        # order_score_list = torch.empty(0)
        # for batch_i in range(batch_size):
        #     order_label = 0
        #     order_score = 0
        #     shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
        #     for sent_i in range(shuffled_sents.shape[0]-1):
        #         sent_embed_1 = avg_sents_repr[batch_i, shuffled_sents[sent_i].item()]
        #         sent_embed_2 = avg_sents_repr[batch_i, shuffled_sents[sent_i+1].item()]
        #         if shuffled_sents[sent_i].item() < shuffled_sents[sent_i+1].item() : order_label += 1
        #         so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
        #         # so_fc_out = self.leak_relu(so_fc_out)
        #         # so_fc_out = self.dropout_layer(so_fc_out)

        #         # so_fc_out = self.so_linear_2(so_fc_out)
        #         # so_fc_out = self.leak_relu(so_fc_out)
        #         # so_fc_out = self.dropout_layer(so_fc_out)

        #         # so_fc_out = self.so_linear_out(so_fc_out)

        #         so_fc_out = self.sigmoid(so_fc_out)
        #         # print(so_fc_out)

        #         order_score += so_fc_out
        #     order_label_list.append(float(order_label))
        #     if order_score == 0: order_score = torch.Tensor([0.0]).cuda()
        #     if order_score_list.shape[0] == 0:
        #         order_score_list = order_score
        #     else:
        #         order_score_list = torch.cat((order_score_list, order_score), dim=0)
        # # print(order_label_list) # [6, 8, 8, 6, 3, 8, 8, 7, 5, 6, 7, 12, 5, 4, 11, 4]
        # # print(order_score_list) # tensor([4.6839, 2.7560, 3.9098, 5.5054, 0.7618, 3.9745, 7.4416, 6.4322, 1.1115,2.5257, 2.6214, 6.3680, 1.9489, 3.5597, 6.2616, 1.5653],device='cuda:0', grad_fn=<CatBackward>)

        ### Sentence Ordering Task : list_score (BCELoss)
        ### avg_sents_repr : [batch_size, sent_num]
        order_label_list = []
        order_score_list = []
        for batch_i in range(batch_size):
            order_label = []
            order_score = []
            shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
            # print("shuffled_sents: ", len(shuffled_sents))
            for sent_i in range(shuffled_sents.shape[0] - 1):
                sent_embed_1 = encoded_sents[batch_i,
                                             shuffled_sents[sent_i].item()]
                sent_embed_2 = encoded_sents[batch_i,
                                             shuffled_sents[sent_i + 1].item()]
                if shuffled_sents[sent_i].item() < shuffled_sents[sent_i +
                                                                  1].item():
                    order_label.append(1.0)
                else:
                    order_label.append(0.0)
                so_fc_out = self.so_linear_1(
                    torch.cat((sent_embed_1, sent_embed_2), dim=0))
                # so_fc_out = self.leak_relu(so_fc_out)
                # so_fc_out = self.dropout_layer(so_fc_out)

                so_fc_out = self.so_linear_2(so_fc_out)
                # so_fc_out = self.leak_relu(so_fc_out)
                # so_fc_out = self.dropout_layer(so_fc_out)

                so_fc_out = self.so_linear_out(so_fc_out)

                so_fc_out = self.sigmoid(so_fc_out)
                # print(so_fc_out)
                order_score.append(so_fc_out)

                # if order_score.shape[0] == 0: order_score = so_fc_out
                # else : order_score = torch.cat((order_score, so_fc_out), dim=0)

            order_label_list.append(order_label)
            order_score_list.append(order_score)
            # print("order_label: ", order_label, sum(order_label))
            # print("order_score: ", order_score)

            # if order_score == 0: order_score = torch.Tensor([0.0]).cuda()
            # if order_score_list.shape[0] == 0:
            #     order_score_list = order_score
            # else:
            #     order_score_list = torch.cat((order_score_list, order_score), dim=0)
        # print(order_label_list) # [6, 8, 8, 6, 3, 8, 8, 7, 5, 6, 7, 12, 5, 4, 11, 4]
        # print(order_score_list) # tensor([4.6839, 2.7560, 3.9098, 5.5054, 0.7618, 3.9745, 7.4416, 6.4322, 1.1115,2.5257, 2.6214, 6.3680, 1.9489, 3.5597, 6.2616, 1.5653],device='cuda:0', grad_fn=<CatBackward>)

        if len(order_label_list) != len(order_score_list):
            order_label_list = []
            order_score_list = []

        outputs.append(order_label_list)
        outputs.append(order_score_list)

        # return fc_out
        return outputs
Example #21
    def forward(self,
                batch_size,
                inputs=None,
                init_state=None,
                attn_context=None,
                mode=TEACH_FORCE,
                gen_type='greedy',
                beam_size=4):

        # sanity checks
        ret_dict = dict()
        h_0 = init_state.squeeze(0).unsqueeze(1).repeat(
            1, self.max_length - 1, 1)
        if mode == GEN:
            inputs = None

        if gen_type != 'beam':
            beam_size = 1

        if inputs is not None:
            decoder_input = inputs
        else:
            # prepare the BOS inputs
            with torch.no_grad():
                bos_var = Variable(torch.LongTensor([self.sos_id]))
            bos_var = cast_type(bos_var, LONG, self.use_gpu)
            decoder_input = bos_var.expand(batch_size * beam_size, 1)
            # logging.info("dec input: {}".format(decoder_input.shape))
            h_0 = init_state.squeeze(0).unsqueeze(1).repeat(beam_size, 1,
                                                            1)  # 12, 1, 100
            # logging.info("h0: {}".format(h_0.shape))

        if mode == GEN and gen_type == 'beam':
            # if beam search, repeat the initial states of the RNN
            if self.rnn_cell is nn.LSTM:
                h, c = init_state
                decoder_hidden = (self.repeat_state(h, batch_size, beam_size),
                                  self.repeat_state(c, batch_size, beam_size))
            else:
                decoder_hidden = self.repeat_state(init_state, batch_size,
                                                   beam_size)
        else:
            decoder_hidden = init_state

        decoder_outputs = []  # a list of logprob
        sequence_symbols = []  # a list word ids
        back_pointers = []  # a list of parent beam ID
        lengths = np.array([self.max_length] * batch_size * beam_size)

        def decode(step, cum_sum, step_output, step_attn):
            decoder_outputs.append(step_output)
            step_output_slice = step_output.squeeze(1)

            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)

            if gen_type == 'greedy':
                symbols = step_output_slice.topk(1)[1]
            elif gen_type == 'sample':
                symbols = self.gumbel_max(step_output_slice)
            elif gen_type == 'beam':
                if step == 0:
                    seq_score = step_output_slice.view(batch_size, -1)
                    seq_score = seq_score[:, 0:self.output_size]
                else:
                    seq_score = cum_sum + step_output_slice
                    seq_score = seq_score.view(batch_size, -1)

                top_v, top_id = seq_score.topk(beam_size)

                back_ptr = top_id.div(self.output_size).view(-1, 1)
                symbols = top_id.fmod(self.output_size).view(-1, 1)
                cum_sum = top_v.view(-1, 1)
                back_pointers.append(back_ptr)
            else:
                raise ValueError("Unsupported decoding mode")

            sequence_symbols.append(symbols)

            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return cum_sum, symbols

        # Manual unrolling is used to support random teacher forcing.
        # If teacher_forcing_ratio is True or False instead of a probability,
        # the unrolling can be done in graph
        if mode == TEACH_FORCE:
            decoder_output, decoder_hidden, attn = self.forward_step(
                decoder_input, decoder_hidden, attn_context, h_0)

            # in teach forcing mode, we don't need symbols.
            decoder_outputs = decoder_output

        else:
            # do free running here
            cum_sum = None
            for di in range(self.max_length):
                decoder_output, decoder_hidden, step_attn = self.forward_step(
                    decoder_input, decoder_hidden, attn_context, h_0)

                cum_sum, symbols = decode(di, cum_sum, decoder_output,
                                          step_attn)
                decoder_input = symbols

            decoder_outputs = torch.cat(decoder_outputs, dim=1)

            if gen_type == 'beam':
                # do back tracking here to recover the 1-best according to
                # beam search.
                final_seq_symbols = []
                cum_sum = cum_sum.view(-1, beam_size)
                max_seq_id = cum_sum.topk(1)[1].data.cpu().view(-1).numpy()
                rev_seq_symbols = sequence_symbols[::-1]
                rev_back_ptrs = back_pointers[::-1]

                for symbols, back_ptrs in zip(rev_seq_symbols, rev_back_ptrs):
                    symbol2ds = symbols.view(-1, beam_size)
                    back2ds = back_ptrs.view(-1, beam_size)
                    # logging.info(symbol2ds)
                    selected_symbols = []
                    selected_parents = []
                    for b_id in range(batch_size):
                        selected_parents.append(back2ds[b_id,
                                                        max_seq_id[b_id]])
                        selected_symbols.append(symbol2ds[b_id,
                                                          max_seq_id[b_id]])
                    # logging.info(selected_symbols)
                    final_seq_symbols.append(
                        torch.stack(selected_symbols, 0).unsqueeze(1))
                    max_seq_id = torch.stack(
                        selected_parents).data.cpu().numpy()
                sequence_symbols = final_seq_symbols[::-1]

        # save the decoded sequence symbols and sequence length
        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, decoder_hidden, ret_dict
Example #22
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                len_para=None,
                list_rels=None,
                mode=""):
        # mask_input: (batch, max_tokens), len_sents: (batch, max_num_sents)
        batch_size = text_inputs.size(0)

        mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
        mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
        num_sents = mask_sent.sum(dim=1)  # (batch_size)

        #### Stage1 and 2: sentence repr and discourse segments parser
        adj_mat, sent_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind = self.centering_attn(
            text_inputs, mask_input, len_sents, num_sents, tid)

        # #### doc-level encoding input text (disable this part if GPU memory is not enough)
        # encoder_doc_out = self.base_encoder(text_inputs, mask_input, len_seq)
        # encoded_doc = encoder_doc_out[0]
        # if self.output_attentions:
        #     attn_doc_avg = encoder_doc_out[1]  # averaged mh attentions (batch, item, item)
        # mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
        # mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
        # num_sents = mask_sent.sum(dim=1)  # (batch_size)
        # sent_repr = self.sent_repr_avg(batch_size, encoded_doc, len_sents)

        # torch.set_printoptions(profile="full")
        # print(adj_mat[0])

        #### Stage3: Structure-aware transformer
        mask_sent_tr = torch.arange(
            self.max_num_sents, device=num_sents.device).expand(
                len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
        mask_sent_tr = utils.cast_type(mask_sent_tr, BOOL, self.use_gpu)
        encoded_sents, break_probs = self.tt_encoder(
            sent_repr, mask_sent_tr, adj_mat
        )  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

        #### Stage4: Document Attention
        context_weight = self.context_weight.expand(encoded_sents.shape[0],
                                                    encoded_sents.shape[2], 1)
        attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
        attn_weight = self.tanh(attn_weight)
        attn_weight = masked_softmax(attn_weight, mask_sent)
        attn_vec = torch.bmm(encoded_sents.transpose(1, 2),
                             attn_weight.unsqueeze(2))
        ilc_vec = attn_vec.squeeze(2)

        #### FC layer
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)

        if self.output_size == 1:
            fc_out = self.sigmoid(fc_out)

        # prepare output to return
        outputs = []

        if self.gen_logs:
            outputs.append(batch_adj_list)
            outputs.append(batch_root_list)
            outputs.append(batch_segMap)
            outputs.append(batch_cp_ind)
            outputs.append(num_sents.tolist())

        ### Sentence Ordering Task : list_score (BCELoss)
        ### avg_sents_repr : [batch_size, sent_num]

        ## Training pass
        order_label_list = []
        order_score_list = []
        order_score_MaxMin_list = []
        for batch_i in range(batch_size):
            order_label = []
            order_score = []
            order_right_score = []
            order_score_MaxMin = []
            shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
            # print("shuffled_sents: ", len(shuffled_sents))
            for sent_i in range(shuffled_sents.shape[0] - 1):
                sent_embed_1 = encoded_sents[batch_i,
                                             shuffled_sents[sent_i].item()]
                sent_embed_2 = encoded_sents[batch_i,
                                             shuffled_sents[sent_i + 1].item()]
                so_fc_out = self.so_linear_1(
                    torch.cat((sent_embed_1, sent_embed_2), dim=0))
                # so_fc_out = self.leak_relu(so_fc_out)
                # so_fc_out = self.dropout_layer(so_fc_out)

                so_fc_out = self.so_linear_2(so_fc_out)
                # so_fc_out = self.leak_relu(so_fc_out)
                # so_fc_out = self.dropout_layer(so_fc_out)

                so_fc_out = self.so_linear_out(so_fc_out)
                so_fc_out = self.sigmoid(so_fc_out)

                order_score.append(so_fc_out)

                if shuffled_sents[sent_i].item() < shuffled_sents[sent_i +
                                                                  1].item():
                    order_label.append(1.0)
                    order_right_score.append(so_fc_out)
                else:
                    order_label.append(0.0)

            if len(order_right_score) != 0:
                # order_score_MaxMin.append(max(order_right_score))
                # order_score_MaxMin.append(min(order_right_score))
                order_score_MaxMin.append(
                    sum(order_right_score) / len(order_right_score))

            order_label_list.append(order_label)
            order_score_list.append(order_score)
            order_score_MaxMin_list.append(order_score_MaxMin)

        if len(order_label_list) != len(order_score_list):
            order_label_list = []
            order_score_list = []

        if len(order_score_MaxMin_list) != batch_size:
            order_score_MaxMin_list = []

        outputs.append(fc_out)
        outputs.append(order_label_list)
        outputs.append(order_score_list)
        outputs.append(order_score_MaxMin_list)

        # return fc_out
        return outputs

    # end def forward


# end class
Example #23
 def sample_gumbel(self, logits, use_gpu, eps=1e-20):
     u = torch.rand(logits.size())
     sample = -torch.log(-torch.log(u + eps) + eps)
     sample = cast_type(sample, FLOAT, use_gpu)
     return sample
Example #24
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                len_para=None,
                list_rels=None,
                mode=""):
        # print(text_inputs)
        batch_size = text_inputs.size(0)

        #### stage1: sentence level representations
        sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
        sent_mask = utils.cast_type(sent_mask, FLOAT, self.use_gpu)
        num_sents = sent_mask.sum(dim=1)  # (batch_size) the sign mask above gives the number of sentences per document

        avg_sents_repr = torch.zeros(
            batch_size, self.max_num_sents, self.base_encoder.encoder_out_size
        )  # averaged sents repr in the sent level encoding
        avg_sents_repr = utils.cast_type(avg_sents_repr, FLOAT, self.use_gpu)

        cur_ind = torch.zeros(batch_size, dtype=torch.int64)
        cur_ind = utils.cast_type(cur_ind, LONG, self.use_gpu)
        len_sents = utils.cast_type(len_sents, LONG, self.use_gpu)
        for sent_i in range(self.max_num_sents):  # max_num_sents: number of sentences per document
            cur_sent_lens = len_sents[:, sent_i]  # (batch_size) length of the sent_i-th sentence for each document in the batch
            cur_max_len = int(torch.max(cur_sent_lens))  # maximum sentence length at this position

            if cur_max_len > 0:  # if the maximum is > 0, i.e., not all padding at this position
                cur_sent_ids = torch.zeros(batch_size,
                                           cur_max_len,
                                           dtype=torch.int64)
                cur_sent_ids = utils.cast_type(cur_sent_ids, LONG,
                                               self.use_gpu)
                cur_mask = torch.zeros(batch_size,
                                       cur_max_len,
                                       dtype=torch.int64)
                cur_mask = utils.cast_type(cur_mask, FLOAT, self.use_gpu)

                prev_ind = cur_ind  # two pointers: one marks the start of the current sentence
                cur_ind = cur_ind + cur_sent_lens  # and the other marks its end

                for batch_ind, sent_len in enumerate(cur_sent_lens):
                    cur_loc = cur_ind[batch_ind]
                    prev_loc = prev_ind[batch_ind]
                    cur_sent_ids[batch_ind, :cur_loc -
                                 prev_loc] = text_inputs[batch_ind,
                                                         prev_loc:cur_loc]
                    cur_mask[batch_ind, :cur_loc -
                             prev_loc] = mask_input[batch_ind,
                                                    prev_loc:cur_loc]

            cur_encoded = self.base_encoder(cur_sent_ids, cur_mask,
                                            cur_sent_lens)

            encoded_sent = cur_encoded[
                0]  # encoded output for the current sent

            cur_sent_lens = cur_sent_lens + 1e-9  # prevent zero division
            cur_avg_repr = torch.div(torch.sum(encoded_sent, dim=1),
                                     cur_sent_lens.unsqueeze(1))

            avg_sents_repr[:, sent_i] = cur_avg_repr

        # encoder sentence
        mask_sent = torch.arange(
            self.max_num_sents, device=num_sents.device).expand(
                len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
        mask_sent = utils.cast_type(mask_sent, BOOL, self.use_gpu)
        num_sents = utils.cast_type(num_sents, FLOAT, self.use_gpu)
        encoded_sents = avg_sents_repr

        #### Stage2: Avg
        ilc_vec = torch.div(torch.sum(encoded_sents, dim=1),
                            num_sents.unsqueeze(1))

        #### FC layer
        # fc1 + (leaky_relu + dropout) + fc2 + (leaky_relu + dropout) + fc3 + sigmoid: three linear layers
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)

        if self.output_size == 1:
            fc_out = self.sigmoid(fc_out)

        outputs = []
        outputs.append(fc_out)

        # return fc_out

        # Multi-task addition: score each pair of adjacent sentences; assuming a target score of 1 per pair, the objective is the MSE between the summed scores and (num_sents - 1)
        # print(avg_sents_repr)
        order_score = torch.zeros(batch_size)
        order_score = utils.cast_type(order_score, FLOAT, self.use_gpu)

        for batch_i in range(batch_size):
            batch_score = 0.0
            for sent_i in range(int(num_sents[batch_i].item()) - 1):
                sents_repr_concated = torch.cat(
                    (avg_sents_repr[batch_i, sent_i],
                     avg_sents_repr[batch_i, sent_i + 1]), 0)
                fc_order_out = self.linear_order_1(sents_repr_concated)
                fc_order_out = self.leak_relu(fc_order_out)
                fc_order_out = self.dropout_layer(fc_order_out)

                fc_order_out = self.linear_order_2(fc_order_out)
                fc_order_out = self.leak_relu(fc_order_out)
                fc_order_out = self.dropout_layer(fc_order_out)

                fc_order_out = self.linear_order_out(fc_order_out)

                batch_score += self.sigmoid(fc_order_out)
            order_score[batch_i] = batch_score
            # print(order_score)

        # for sent_i in range(self.max_num_sents-1):
        #     sents_repr_concated = torch.zeros(batch_size, self.base_encoder.encoder_out_size)
        #     sents_repr_concated = utils.cast_type(sents_repr_concated, FLOAT, self.use_gpu)
        #     sents_repr_concated = torch.cat((avg_sents_repr[:, sent_i], avg_sents_repr[:, sent_i+1]), 1)
        #     fc_order_out = self.linear_order_1(sents_repr_concated)
        #     fc_order_out = self.leak_relu(fc_order_out)
        #     fc_order_out = self.dropout_layer(fc_order_out)

        #     fc_order_out = self.linear_order_2(fc_order_out)
        #     fc_order_out = self.leak_relu(fc_order_out)
        #     fc_order_out = self.dropout_layer(fc_order_out)

        #     fc_order_out = self.linear_order_out(fc_order_out)

        #     score = self.sigmoid(fc_order_out)
        #     # print(score.reshape(batch_size))

        #     order_score += score.reshape(batch_size)

        print(num_sents)
        print(order_score)
        outputs.append(order_score)
        outputs.append(num_sents)

        return outputs
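# --- Illustrative sketch, not part of the original code ---
# The multi-task comment above defines the auxiliary objective as the MSE between the sum of
# adjacent-pair scores and (num_sents - 1): if every adjacent pair is scored 1, the sum equals
# the number of pairs. A minimal, hypothetical loss helper assuming the forward above returns
# [fc_out, order_score, num_sents]; `main_loss` and `lambda_order` are illustrative names,
# not part of the repository.
import torch.nn.functional as F

def order_mse_loss(order_score, num_sents):
    # every adjacent pair should ideally score 1, so the ideal sum is num_sents - 1
    target = num_sents - 1.0
    return F.mse_loss(order_score, target)

# usage sketch:
#   fc_out, order_score, num_sents = model(...)
#   loss = main_loss + lambda_order * order_mse_loss(order_score, num_sents)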
    def forward(self, text_inputs, mask_input, len_seq, len_sents, tid, len_para=None, list_rels=None, mode=""):
        # mask_input: (batch, max_tokens), len_sents: (batch, max_num_sents)
        batch_size = text_inputs.size(0)

        mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
        mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
        num_sents = mask_sent.sum(dim=1)  # (batch_size)

        #### Stage1 and 2: sentence repr and discourse segments parser
        adj_mat, sent_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind = self.centering_attn(text_inputs, mask_input, len_sents, num_sents, tid)

        # #### doc-level encoding of the input text (disable this part if GPU memory is insufficient)
        # encoder_doc_out = self.base_encoder(text_inputs, mask_input, len_seq)
        # encoded_doc = encoder_doc_out[0]
        # if self.output_attentions:
        #     attn_doc_avg = encoder_doc_out[1]  # averaged mh attentions (batch, item, item)
        # mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
        # mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
        # num_sents = mask_sent.sum(dim=1)  # (batch_size)
        # sent_repr = self.sent_repr_avg(batch_size, encoded_doc, len_sents)
        
        # torch.set_printoptions(profile="full")
        # print(adj_mat[0])

        #### Stage3: Structure-aware transformer
        mask_sent_tr = torch.arange(self.max_num_sents, device=num_sents.device).expand(len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
        mask_sent_tr = utils.cast_type(mask_sent_tr, BOOL, self.use_gpu)
        encoded_sents, break_probs = self.tt_encoder(sent_repr, mask_sent_tr, adj_mat)  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

        #### Stage4: Document Attention
        context_weight = self.context_weight.expand(encoded_sents.shape[0], encoded_sents.shape[2], 1)
        attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
        attn_weight = self.tanh(attn_weight)
        attn_weight = masked_softmax(attn_weight, mask_sent)
        attn_vec = torch.bmm(encoded_sents.transpose(1, 2), attn_weight.unsqueeze(2))
        ilc_vec = attn_vec.squeeze(2)

        #### FC layer
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)

        if self.output_size == 1:
            fc_out = self.sigmoid(fc_out)

        # prepare output to return
        outputs = []

        if self.gen_logs:
            outputs.append(batch_adj_list)
            outputs.append(batch_root_list)
            outputs.append(batch_segMap)
            outputs.append(batch_cp_ind)
            outputs.append(num_sents.tolist())
        

        ### Sentence Ordering Task: pairwise order scores trained with BCELoss (see the sketch after this class)
        ### encoded_sents: (batch_size, max_num_sents, hidden) sentence representations

        ## In predict mode, skip the shuffling step and obtain the sentence order score directly
        # def Gaussian_prob(data, avg, sig):
        #     data = data.cpu()
        #     sqrt_2pi = np.power(2 * np.pi, 0.5)
        #     coef = 1 / (sqrt_2pi * sig)
        #     powercoef = -1 / (2 * np.power(sig, 2))
        #     mypow = powercoef * (np.power((data-avg), 2))
        #     return coef * (np.exp(mypow)).cuda()
            

        # if mode == "predict":
        #     batch_order_score = []
        #     for batch_i in range(batch_size):
        #         order_score = []
        #         for sent_i in range(int(num_sents[batch_i].item())-1):
        #             sent_embed_1 = encoded_sents[batch_i, sent_i]
        #             sent_embed_2 = encoded_sents[batch_i, sent_i+1]
        #             so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
        #             so_fc_out = self.so_linear_2(so_fc_out)
        #             so_fc_out = self.so_linear_out(so_fc_out)
        #             so_fc_out = self.sigmoid(so_fc_out)
        #             order_score.append(so_fc_out)
        #         if len(order_score) != 0:
        #             sent_order_score = torch.mean(torch.stack(order_score), dim=0)
        #         else: sent_order_score = torch.Tensor([0.]).cuda()
        #         class_score = torch.cat((Gaussian_prob(sent_order_score, 0, 0.4), torch.cat((Gaussian_prob(sent_order_score, 0.5, 0.4), Gaussian_prob(sent_order_score, 1, 0.4)), dim=0)), dim=0)
        #         batch_order_score.append(class_score)
        #         print("sent_order_score: ", sent_order_score)
        #     batch_order_score = torch.stack((batch_order_score), dim=0)
        #     # print("batch_order_score", batch_order_score)
        #     print("fc_out: ", fc_out)
        #     fc_out += batch_order_score

        #     outputs.append(fc_out)
        #     return outputs


        ## Training procedure
        order_label_list = []
        order_score_list = []
        for batch_i in range(batch_size):
            order_label = []
            order_score = []
            shuffled_sents = torch.randperm(int(num_sents[batch_i].item()))
            # print("shuffled_sents: ", len(shuffled_sents))
            for sent_i in range(shuffled_sents.shape[0]-1):
                sent_embed_1 = encoded_sents[batch_i, shuffled_sents[sent_i].item()]
                sent_embed_2 = encoded_sents[batch_i, shuffled_sents[sent_i+1].item()]
                if shuffled_sents[sent_i].item() < shuffled_sents[sent_i + 1].item():
                    order_label.append(1.0)
                else:
                    order_label.append(0.0)
                so_fc_out = self.so_linear_1(torch.cat((sent_embed_1, sent_embed_2), dim=0))
                # so_fc_out = self.leak_relu(so_fc_out)
                # so_fc_out = self.dropout_layer(so_fc_out)

                so_fc_out = self.so_linear_2(so_fc_out)
                # so_fc_out = self.leak_relu(so_fc_out)
                # so_fc_out = self.dropout_layer(so_fc_out)

                so_fc_out = self.so_linear_out(so_fc_out)

                so_fc_out = self.sigmoid(so_fc_out)
                # print(so_fc_out)
                order_score.append(so_fc_out)

                # if order_score.shape[0] == 0: order_score = so_fc_out
                # else : order_score = torch.cat((order_score, so_fc_out), dim=0)

            order_label_list.append(order_label)
            order_score_list.append(order_score)
            # print("order_label: ", order_label, sum(order_label))
            # print("order_score: ", order_score)


            # if order_score == 0: order_score = torch.Tensor([0.0]).cuda()
            # if order_score_list.shape[0] == 0: 
            #     order_score_list = order_score
            # else:
            #     order_score_list = torch.cat((order_score_list, order_score), dim=0)
        # print(order_label_list) # [6, 8, 8, 6, 3, 8, 8, 7, 5, 6, 7, 12, 5, 4, 11, 4]
        # print(order_score_list) # tensor([4.6839, 2.7560, 3.9098, 5.5054, 0.7618, 3.9745, 7.4416, 6.4322, 1.1115,2.5257, 2.6214, 6.3680, 1.9489, 3.5597, 6.2616, 1.5653],device='cuda:0', grad_fn=<CatBackward>)

        # defensive reset: drop both lists if they somehow got out of sync
        if len(order_label_list) != len(order_score_list):
            order_label_list = []
            order_score_list = []
        
        outputs.append(fc_out)
        outputs.append(order_label_list)
        outputs.append(order_score_list)

        # return fc_out
        return outputs
    # end def forward

# end class
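# --- Illustrative sketch, not part of the original code ---
# The forward above returns `order_label_list` / `order_score_list` as nested per-document lists,
# and the "Sentence Ordering Task" comment points to a BCELoss objective over the adjacent-pair
# scores. A minimal, hypothetical way to flatten those lists and compute the loss; the function
# name and flattening are assumptions, not the repository's training code.
import torch
import torch.nn as nn

def sentence_order_bce(order_score_list, order_label_list):
    # flatten the nested per-document lists into one 1-D tensor each
    scores = [s.view(-1) for doc in order_score_list for s in doc]
    labels = [l for doc in order_label_list for l in doc]
    if len(scores) == 0:
        return torch.tensor(0.0)  # e.g. every document had a single sentence
    scores = torch.cat(scores)
    labels = torch.tensor(labels, dtype=torch.float, device=scores.device)
    return nn.BCELoss()(scores, labels)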
Пример #26
0
	def __init__(self, config, corpus_target):
		logger.info('Loading embeddings from: ' + config.path_pretrained_emb)

		# Init parameters
		self.emb_path = config.path_pretrained_emb
		self.embed_size = config.embed_size
		if config.encoder_type =="transf":
			self.embed_size = config.d_model
		self.embedding_dim = None

		self.embeddings = {}
		self.emb_matrix = None

		self.x_embed = None  # embedding layer will be returned

		self.tokenizer = corpus_target.tokenizer

		# load embedding
		self.pad_id = None
		if self.emb_path.startswith("bert-"):
			# BERT returns the embedding layer itself, not a weight matrix like other pretrained embeddings
			self.x_embed = self.load_pretrained_bert() # from bert library
			self.embedding_dim = 768
			# self.pad_id = 0
			self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
		elif self.emb_path.startswith("xlnet-"):
			self.x_embed = None  # xlnet does not use additional pretrained embedding class
			# self.embedding_dim = 768
			# if "large" in self.emb_path:
			# 	self.embedding_dim = 1024
			self.embedding_dim = 0  # deprecated, should be cleaned up later (2020.03.14)
			self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
		elif self.emb_path.startswith("t5-"):
			self.x_embed = None  # t5 does not use an additional pretrained embedding class
			self.embedding_dim = 1024
			self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
		elif self.emb_path.startswith("bart-"):
			self.x_embed = None  # bart does not use an additional pretrained embedding class
			self.embedding_dim = 1024
			self.pad_id = self.tokenizer.sp_model.piece_to_id("<pad>")
		else:
			# manual version
			self.vocab_size = len(corpus_target.vocab)
			self.vocab = corpus_target.vocab  # word2id
			self.rev_vocab = corpus_target.rev_vocab  # id2word
			# self.pad_id = self.rev_vocab[PAD]
			self.pad_id = corpus_target.pad_id

			self.num_special_vocab = corpus_target.num_special_vocab


			if not self.emb_path.lower().startswith("none") and len(self.emb_path) > 1:
				self.load_pretrained_file()  # from pretrained file
			self.emb_matrix = np.zeros((len(self.vocab), self.embed_size))
			self.get_emb_matrix_given_vocab()  # assign emb_matrix -> (len(vocab), embed_size)
			self.x_embed = nn.Embedding(self.vocab_size, self.embed_size, padding_idx=self.pad_id)
			self.x_embed = self.x_embed.from_pretrained(torch.FloatTensor(self.emb_matrix))
			self.x_embed.weight.data[self.pad_id] = 0.0 # zero padding

			# padding_idx is lost when using "from_pretrained" in PyTorch 1.0 (bug?), so restore it manually
			self.x_embed.padding_idx = self.pad_id

			self.embedding_dim = self.x_embed.embedding_dim

		if config.use_gpu and self.x_embed is not None:
			self.x_embed = utils.cast_type(self.x_embed, FLOAT, config.use_gpu)


		return
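# --- Illustrative sketch, not part of the original code ---
# `load_pretrained_file` and `get_emb_matrix_given_vocab` are called above but their bodies are
# not shown in this excerpt. A minimal, hypothetical version of the matrix fill they imply: copy
# each in-vocabulary token's pretrained vector into the matrix, leave unseen tokens at zero, and
# keep the padding row zeroed. `embeddings` maps token -> np.ndarray and `vocab` maps token -> id,
# matching the attributes initialised above.
import numpy as np

def build_emb_matrix(vocab, embeddings, embed_size, pad_id):
    emb_matrix = np.zeros((len(vocab), embed_size), dtype=np.float32)
    for token, idx in vocab.items():
        vec = embeddings.get(token)
        if vec is not None:
            emb_matrix[idx] = vec      # copy the pretrained vector for known tokens
    emb_matrix[pad_id] = 0.0           # padding row stays zero
    return emb_matrix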
Пример #27
0
    def np2var(self, inputs, dtype):
        if inputs is None:
            return None
        return utils.cast_type(Variable(torch.from_numpy(inputs)), dtype,
                               self.use_gpu)
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                len_para=None,
                mode=""):
        batch_size = text_inputs.size(0)

        # #
        if self.pad_level == "sent" or self.pad_level == "sentence":
            text_inputs = text_inputs.view(
                batch_size,
                text_inputs.size(1) * text_inputs.size(2))

        #### word level encoding
        encoder_out = self.base_encoder(text_inputs, mask_input, len_seq)

        #### sentence representations
        sent_mask = torch.sign(len_sents)  # (batch_size, len_sent)
        num_sents = sent_mask.sum(dim=1)  # (batch_size)

        sent_repr = torch.zeros(batch_size, self.max_num_sents,
                                self.base_encoder.encoder_out_size)
        sent_repr = utils.cast_type(sent_repr, FLOAT, self.use_gpu)
        for cur_ind_doc in range(batch_size):
            list_sent_len = len_sents[cur_ind_doc]
            cur_sent_num = int(num_sents[cur_ind_doc])
            cur_loc_sent = 0
            list_cur_doc_sents = []

            for cur_ind_sent in range(cur_sent_num):
                cur_sent_len = int(list_sent_len[cur_ind_sent])
                cur_sent_repr = torch.div(
                    torch.sum(
                        encoder_out[cur_ind_doc,
                                    cur_loc_sent:cur_loc_sent + cur_sent_len],
                        dim=0), cur_sent_len)  # avg version

                cur_sent_repr = cur_sent_repr.view(
                    1, 1, -1)  # restore to (1, 1, rnn_cell_size)

                list_cur_doc_sents.append(cur_sent_repr)
                cur_loc_sent = cur_loc_sent + cur_sent_len
            # end for cur_len_sent

            cur_sents_repr = torch.stack(
                list_cur_doc_sents,
                dim=1)  # (1, num_sents, 1, rnn_cell_size)
            cur_sents_repr = cur_sents_repr.squeeze(2)  # (1, num_sents, rnn_cell_size)

            sent_repr[cur_ind_doc, :cur_sent_num, :] = cur_sents_repr
        # end for cur_doc

        # sentence-level mask
        mask_sent = torch.arange(
            self.max_num_sents, device=num_sents.device).expand(
                len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
        mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
        num_sents = utils.cast_type(num_sents, FLOAT, self.use_gpu)

        # get averaging
        ilc_vec_sent = torch.div(
            torch.sum(sent_repr, dim=1),
            num_sents.unsqueeze(1))  # (batch_size, rnn_cell_size)
        sim_avg_sent = self.sim_cosine(ilc_vec_sent.unsqueeze(1), sent_repr)
        sim_avg_sent = sim_avg_sent * mask_sent

        # get distance vector
        avg_pooled_sent = torch.zeros(sent_repr.shape[0],
                                      self.size_avg_pool_sent)
        avg_pooled_sent = utils.cast_type(avg_pooled_sent, FLOAT, self.use_gpu)
        for cur_batch, cur_tensor in enumerate(sim_avg_sent):
            cur_seq_len = int(num_sents[cur_batch])
            cur_tensor = cur_tensor.unsqueeze(0)
            crop_tensor = cur_tensor.narrow(1, 0, cur_seq_len)
            crop_tensor = crop_tensor.unsqueeze(1)

            sim_conv = self.conv_sent(
                crop_tensor
            )  ## this part should be tested at the sentence level later
            sim_conv = self.leak_relu(sim_conv)

            sim_conv = self.dropout_01(sim_conv)
            # sim_conv = self.dropout_02(sim_conv)

            cur_avg_pooled = self.max_adapt_pool1_sent(sim_conv)
            avg_pooled_sent[cur_batch, :] = cur_avg_pooled

        #### FC layer
        ilc_vec = torch.cat((ilc_vec_sent, avg_pooled_sent),
                            dim=1)  # concat the centroid and similarity vector

        # scoring head (FC layers)
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)

        if self.output_size == 1:
            fc_out = self.sigmoid(fc_out)

        return fc_out
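# --- Illustrative sketch, not part of the original code ---
# The loop above crops each document's centroid-similarity sequence to its true length and runs
# `self.conv_sent` followed by `self.max_adapt_pool1_sent` to obtain a fixed-size vector regardless
# of the number of sentences. Those layers are defined elsewhere in the repository; a Conv1d +
# AdaptiveMaxPool1d configuration consistent with that usage could look like this (channel counts,
# kernel size, and the pooled size of 10 are assumptions for illustration).
import torch
import torch.nn as nn

conv_sent = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
max_adapt_pool1_sent = nn.AdaptiveMaxPool1d(output_size=10)  # size_avg_pool_sent assumed = 10

sim_seq = torch.rand(1, 1, 23)                    # one document with 23 sentence similarities
pooled = max_adapt_pool1_sent(conv_sent(sim_seq))
print(pooled.shape)                               # torch.Size([1, 1, 10]), independent of 23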
    def forward(self,
                text_inputs,
                mask_input,
                len_seq,
                len_sents,
                tid,
                len_para=None,
                list_rels=None,
                mode=""):
        # mask_input: (batch, max_tokens), len_sents: (batch, max_num_sents)
        batch_size = text_inputs.size(0)

        mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
        mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
        num_sents = mask_sent.sum(dim=1)  # (batch_size)

        #### Stage1 and 2: sentence repr and discourse segments parser
        adj_mat, sent_repr, batch_adj_list, batch_root_list, batch_segMap, batch_cp_ind = self.centering_attn(
            text_inputs, mask_input, len_sents, num_sents, tid)

        # #### doc-level encoding of the input text (disable this part if GPU memory is insufficient)
        # encoder_doc_out = self.base_encoder(text_inputs, mask_input, len_seq)
        # encoded_doc = encoder_doc_out[0]
        # if self.output_attentions:
        #     attn_doc_avg = encoder_doc_out[1]  # averaged mh attentions (batch, item, item)
        # mask_sent = torch.sign(len_sents)  # (batch_size, len_sent)
        # mask_sent = utils.cast_type(mask_sent, FLOAT, self.use_gpu)
        # num_sents = mask_sent.sum(dim=1)  # (batch_size)
        # sent_repr = self.sent_repr_avg(batch_size, encoded_doc, len_sents)

        # torch.set_printoptions(profile="full")
        # print(adj_mat[0])

        #### Stage3: Structure-aware transformer
        mask_sent_tr = torch.arange(
            self.max_num_sents, device=num_sents.device).expand(
                len(num_sents), self.max_num_sents) < num_sents.unsqueeze(1)
        mask_sent_tr = utils.cast_type(mask_sent_tr, BOOL, self.use_gpu)
        encoded_sents, break_probs = self.tt_encoder(
            sent_repr, mask_sent_tr, adj_mat
        )  # ['features'], ['node_order'], ['adjacency_list'], ['edge_order']

        #### Stage4: Document Attention
        context_weight = self.context_weight.expand(encoded_sents.shape[0],
                                                    encoded_sents.shape[2], 1)
        attn_weight = torch.bmm(encoded_sents, context_weight).squeeze(2)
        attn_weight = self.tanh(attn_weight)
        attn_weight = masked_softmax(attn_weight, mask_sent)
        attn_vec = torch.bmm(encoded_sents.transpose(1, 2),
                             attn_weight.unsqueeze(2))
        ilc_vec = attn_vec.squeeze(2)

        #### FC layer
        fc_out = self.linear_1(ilc_vec)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_2(fc_out)
        fc_out = self.leak_relu(fc_out)
        fc_out = self.dropout_layer(fc_out)

        fc_out = self.linear_out(fc_out)

        if self.output_size == 1:
            fc_out = self.sigmoid(fc_out)

        # prepare output to return
        outputs = []
        outputs.append(fc_out)

        if self.gen_logs:
            outputs.append(batch_adj_list)
            outputs.append(batch_root_list)
            outputs.append(batch_segMap)
            outputs.append(batch_cp_ind)
            outputs.append(num_sents.tolist())

        # return fc_out
        return outputs

    # end def forward


# end class
Пример #30
0
    def init_hidden_layers(self, batch_size, num_layers, rnn_cell_size):
        hid = torch.autograd.Variable(torch.zeros(num_layers, batch_size, rnn_cell_size))
        hid = utils.cast_type(hid, FLOAT, self.use_gpu)
        nn.init.xavier_uniform_(hid)

        return hid
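# --- Illustrative sketch, not part of the original code ---
# A hypothetical usage of `init_hidden_layers`: the Xavier-initialised hidden state has shape
# (num_layers, batch_size, rnn_cell_size), which matches the initial hidden state expected by
# nn.GRU / nn.LSTM. The sizes below are assumptions; the enclosing class is not shown here.
import torch
import torch.nn as nn

num_layers, batch_size, rnn_cell_size = 2, 4, 300
rnn = nn.GRU(input_size=300, hidden_size=rnn_cell_size,
             num_layers=num_layers, batch_first=True)

hid = torch.zeros(num_layers, batch_size, rnn_cell_size)
nn.init.xavier_uniform_(hid)                      # same initialisation as init_hidden_layers

inputs = torch.rand(batch_size, 12, 300)          # (batch, seq_len, input_size)
out, last_hid = rnn(inputs, hid)                  # out: (4, 12, 300), last_hid: (2, 4, 300)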