예제 #1
0
    def forward(self,  # type: ignore
                task_index: torch.IntTensor,
                reverse: torch.ByteTensor,
                epoch_trained: torch.IntTensor,
                for_training: torch.ByteTensor,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                text_id: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Encode ``tokens`` and score the result with the sentiment, private-domain
        and shared-domain discriminators.

        :param task_index: index of the task this batch belongs to. It is expanded
            to ``batch_size`` and used as the target of both domain losses.
        :param reverse: forwarded to the encoder and, as ``reverse=``, to the
            shared-domain discriminator.
        :param epoch_trained: forwarded to the encoder.
        :param for_training: forwarded to the encoder. When falsy, the class
            probabilities of the batch are appended to ``class_probabilities.txt``.
        :param tokens: text field tensors, forwarded to the encoder.
        :param label: optional sentiment labels. Losses and metrics are only
            computed when it is given.
        :param text_id: optional ids forwarded to the encoder and written as the
            "Line ID" entry of the probability dump.
        :return: a dict that always holds ``logits`` (the sentiment logits). When
            ``label`` is given it also holds ``tokens``, ``stm_loss``,
            ``p_d_loss``, ``s_d_loss`` and the weighted total ``loss``.
        """
        embeddeds = self._encoder(task_index, tokens, epoch_trained, self._valid_discriminator, reverse, for_training,
                                  text_id)
        batch_size = get_batch_size(embeddeds["embedded_text"])

        sentiment_logits = self._sentiment_discriminator(embeddeds["embedded_text"])
        p_domain_logits = self._p_domain_discriminator(embeddeds["private_embedding"])
        # TODO set reverse = true
        s_domain_logits = self._s_domain_discriminator(embeddeds["share_embedding"], reverse=reverse)

        output_dict = {'logits': sentiment_logits}
        if label is not None:
            loss = self._loss(sentiment_logits, label)
            # Every example in the batch shares the same task, so the single
            # task index is broadcast to serve as the domain target.
            task_index = task_index.expand(batch_size)
            p_domain_loss = self._domain_loss(p_domain_logits, task_index)
            s_domain_loss = self._domain_loss(s_domain_logits, task_index)
            output_dict["tokens"] = tokens
            output_dict['stm_loss'] = loss
            output_dict['p_d_loss'] = p_domain_loss
            output_dict['s_d_loss'] = s_domain_loss
            # TODO add share domain logits std loss
            output_dict['loss'] = loss + 0.06 * p_domain_loss + 0.04 * s_domain_loss

            for metric_name, metric in self.metrics.items():
                if "auc" in metric_name:
                    # AUC metrics are fed the decoded labels instead of raw logits.
                    metric(self.decode(output_dict)["label"], label)
                else:
                    metric(sentiment_logits, label)

        if not for_training:
            def _render(values) -> str:
                # ``label`` and ``text_id`` are optional; an absent tensor is
                # written as an empty field instead of raising AttributeError.
                if values is None:
                    return ""
                return " ".join(map(str, values.cpu().detach().numpy()))

            # ``task_index`` is only expanded when a label was supplied, so it is
            # flattened first to make the lookup work for a 0-dim tensor as well.
            task_name = TASKS_NAME[int(task_index.reshape(-1)[0])]
            probabilities = F.softmax(sentiment_logits, dim=-1)
            with open("class_probabilities.txt", "a", encoding="utf8") as f:
                f.write(f"Task: {task_name}\nLine ID: ")
                f.write(_render(text_id))
                f.write("\nProb: ")
                f.write(_render(probabilities))
                f.write("\nLabel: " + _render(label) + "\n")
                f.write("\n\n\n")
        return output_dict
    def forward(
        self,
        tokens_list: Dict[str, torch.LongTensor],
        positions_list: Dict[str, torch.LongTensor],
        sent_positions_list: Dict[str, torch.LongTensor],
        before_loc_start: torch.IntTensor = None,
        before_loc_end: torch.IntTensor = None,
        after_loc_start_list: torch.IntTensor = None,
        after_loc_end_list: torch.IntTensor = None,
        before_category: torch.IntTensor = None,
        after_category_list: torch.IntTensor = None,
        before_category_mask: torch.IntTensor = None,
        after_category_mask_list: torch.IntTensor = None
    ) -> Dict[str, torch.Tensor]:
        """
        Predict a location category and a location span for every step of a list of paragraphs.

        The paragraphs are processed one step at a time. The first step additionally predicts
        the initial ("before") location; every step predicts the "after" location, and the
        category logits and start probabilities of one step are fed into the next one.

        Although the gold-annotation arguments default to ``None`` in the signature, all of them
        are dereferenced unconditionally below, so they have to be supplied.

        :param tokens_list: Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input. Here it is also used, with one wrapping dimension, to build the
            padding mask.
        :param positions_list: same layout as tokens_list; passed to ``self._pos_field_embedder``
        :param sent_positions_list: same layout as tokens_list; passed to
            ``self._sent_pos_field_embedder``
        :param before_loc_start: torch.IntTensor, required
            An integer ``IndexField`` representation of the before location start
        :param before_loc_end: torch.IntTensor, required
            An integer ``IndexField`` representation of the before location end
        :param after_loc_start_list: torch.IntTensor, required
            A list of integers ``ListField (IndexField)`` representation of the list of after location starts
            along the sequence of steps
        :param after_loc_end_list: torch.IntTensor, required
            A list of integers ``ListField (IndexField)`` representation of the list of after location ends
            along the sequence of steps
        :param before_category: torch.IntTensor, required
            An integer ``IndexField`` representation of the before location category
        :param after_category_list: torch.IntTensor, required
            A list of integers ``ListField (IndexField)`` representation of the list of after location categories
            along the sequence of steps
        :param before_category_mask: torch.IntTensor, required
            An integer ``IndexField`` representation of whether the before location is known or not (0/1)
        :param after_category_mask_list: torch.IntTensor, required
            A list of integers ``ListField (IndexField)`` representation of the list of whether after location is
            known or not for each step along the sequence of steps
        :return:
        An output dictionary consisting of:
        best_span: the merged predicted spans returned by ``self._get_merged_spans``,
            flattened into a single row (shape ``(1, N)``)
        true_span: the merged gold spans, flattened the same way
        loss: the sum of the category, span-start and span-end loss terms accumulated
            over the before prediction and every step
        """

        # batchsize * listLength * paragraphSize * embeddingSize
        input_embedding_paragraph = self._text_field_embedder(tokens_list)
        input_pos_embedding_paragraph = self._pos_field_embedder(
            positions_list)
        input_sent_pos_embedding_paragraph = self._sent_pos_field_embedder(
            sent_positions_list)
        # the three embeddings are concatenated along the last (feature) dimension
        embedding_paragraph = torch.cat([
            input_embedding_paragraph, input_pos_embedding_paragraph,
            input_sent_pos_embedding_paragraph
        ],
                                        dim=-1)

        # batchsize * listLength * paragraphSize,  this mask is shared with the text fields and sequence label fields
        para_mask = util.get_text_field_mask(tokens_list,
                                             num_wrapping_dims=1).float()

        # batchsize * listLength ,  this mask is shared with the index fields
        # (the max over the token dimension is non-zero as soon as a step has one unmasked token)
        para_index_mask, para_index_mask_indices = torch.max(para_mask, 2)

        # apply mask to update the index values,  padded instances will be 0
        after_loc_start_list = (after_loc_start_list.float() *
                                para_index_mask.unsqueeze(2)).long()
        after_loc_end_list = (after_loc_end_list.float() *
                              para_index_mask.unsqueeze(2)).long()
        after_category_list = (after_category_list.float() *
                               para_index_mask.unsqueeze(2)).long()
        after_category_mask_list = (after_category_mask_list.float() *
                                    para_index_mask.unsqueeze(2)).long()

        batch_size, list_size, paragraph_size, input_dim = embedding_paragraph.size(
        )

        # to store the values passed to next step
        # (placeholders only: both are overwritten in the ``index == 0`` branch before first use)
        # NOTE(review): created without an explicit device, unlike the model inputs - confirm
        tmp_category_probability = torch.zeros(batch_size, 3)
        tmp_start_probability = torch.zeros(batch_size, paragraph_size)

        loss = 0

        # store the predict logits for the whole lists
        # (every step index is written inside the loop, so the random initial values do not survive;
        # the copies below go through ``.data``)
        category_predict_logits_after_list = torch.rand(
            batch_size, list_size, 3)
        best_span_after_list = torch.rand(batch_size, list_size, 2)

        # NOTE(review): ``best_span_before`` / ``true_span_before`` are only assigned when the loop
        # runs with index 0, so an empty list dimension would leave them undefined further down
        for index in range(list_size):
            # get one slice of step for prediction
            embedding_paragraph_slice = embedding_paragraph[:,
                                                            index, :, :].squeeze(
                                                                1)
            para_mask_slice = para_mask[:, index, :].squeeze(1)
            para_lstm_mask_slice = para_mask_slice if self._mask_lstms else None
            para_index_mask_slice = para_index_mask[:, index]
            after_category_mask_slice = after_category_mask_list[:,
                                                                 index, :].squeeze(
                                                                 )

            # bi-LSTM: generate the contextual embeddings for the current step
            # size: batchsize * paragraph_size * modeling_layer_hidden_size
            encoded_paragraph = self._dropout(
                self._modeling_layer(embedding_paragraph_slice,
                                     para_lstm_mask_slice))

            # max-pooling output for three category classification
            category_input, category_input_indices = torch.max(
                encoded_paragraph, 1)

            modeling_dim = encoded_paragraph.size(-1)
            span_start_input = encoded_paragraph

            # predict the initial before location state
            if index == 0:

                # three category classification for initial before location
                category_predict_logits_before = self._category_before_predictor(
                    category_input)
                tmp_category_probability = category_predict_logits_before
                '''Model the before_loc prediction'''
                # predict the initial before location start scores
                # shape:  batchsize * paragraph_size
                span_start_logits_before = self._span_start_predictor_before(
                    span_start_input).squeeze(-1)
                # shape:  batchsize * paragraph_size
                span_start_probs_before = util.masked_softmax(
                    span_start_logits_before, para_mask_slice)
                tmp_start_probability = span_start_probs_before

                # shape:  batchsize * hiddensize
                span_start_representation_before = util.weighted_sum(
                    encoded_paragraph, span_start_probs_before)

                # Shape: (batch_size, passage_length, modeling_dim)
                tiled_start_representation_before = span_start_representation_before.unsqueeze(
                    1).expand(batch_size, paragraph_size, modeling_dim)

                # incorporate the original contextual embeddings and weighted sum vector from location start prediction
                # shape: batchsize * paragraph_size * 2hiddensize
                span_end_representation_before = torch.cat(
                    [encoded_paragraph, tiled_start_representation_before],
                    dim=-1)
                # Shape: (batch_size, passage_length, encoding_dim)
                encoded_span_end_before = self._dropout(
                    self._span_end_encoder_before(
                        span_end_representation_before, para_lstm_mask_slice))

                # initial before location end prediction
                encoded_span_end_before = torch.cat(
                    [encoded_paragraph, encoded_span_end_before], dim=-1)
                # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
                span_end_logits_before = self._span_end_predictor_before(
                    encoded_span_end_before).squeeze(-1)
                span_end_probs_before = util.masked_softmax(
                    span_end_logits_before, para_mask_slice)

                # best_span_bef = self._get_best_span(span_start_logits_bef, span_end_logits_bef)
                best_span_before, best_span_before_start, best_span_before_end, best_span_before_real = \
                    self._get_best_span_single_extend(span_start_logits_before, span_end_logits_before,
                                                      category_predict_logits_before, before_category_mask)

                # compute the loss for initial bef location three-category classification
                # NOTE(review): ``nll_loss`` is documented to expect log-probabilities; ``softmax`` here
                # presumably yields plain probabilities - confirm whether ``log_softmax`` was intended
                before_null_pred = softmax(category_predict_logits_before)
                before_null_pred_values, before_null_pred_indices = torch.max(
                    before_null_pred, 1)
                loss += nll_loss(before_null_pred, before_category.squeeze(-1))

                # compute the loss for initial bef location start/end prediction
                # NOTE(review): despite the ``logpy_*`` names, the gathered values come straight from
                # ``util.masked_softmax`` with no ``log`` applied in this method - confirm intent
                before_loc_start_pred = util.masked_softmax(
                    span_start_logits_before, para_mask_slice)
                logpy_before_start = torch.gather(
                    before_loc_start_pred, 1,
                    before_loc_start).view(-1).float()
                before_category_mask = before_category_mask.float()
                loss += -(logpy_before_start * before_category_mask).mean()
                before_loc_end_pred = util.masked_softmax(
                    span_end_logits_before, para_mask_slice)
                logpy_before_end = torch.gather(before_loc_end_pred, 1,
                                                before_loc_end).view(-1)
                loss += -(logpy_before_end * before_category_mask).mean()

                # get the real predicted location spans
                # convert category output (Null and Unk) into spans ((-2,-2) or (-1, -1))
                before_loc_start_real = self._get_real_spans_extend(
                    before_loc_start, before_category, before_category_mask)
                before_loc_end_real = self._get_real_spans_extend(
                    before_loc_end, before_category, before_category_mask)
                true_span_before = torch.stack(
                    [before_loc_start_real, before_loc_end_real], dim=-1)
                true_span_before = true_span_before.squeeze(1)

            # input for (after location) three category classification
            # (pooled encoding concatenated with the category logits carried over from the previous prediction)
            category_input_after = torch.cat(
                (category_input, tmp_category_probability), dim=1)
            category_predict_logits_after = self._category_after_predictor(
                category_input_after)
            tmp_category_probability = category_predict_logits_after

            # copy the predict logits for the index of the list
            category_predict_logits_after_tmp = category_predict_logits_after.unsqueeze(
                1)
            category_predict_logits_after_list[:,
                                               index, :] = category_predict_logits_after_tmp.data
            '''  Model the after_loc prediction  '''
            # after location start prediction: takes contextual embeddings and weighted sum vector as input
            # shape:  batchsize * hiddensize
            # NOTE(review): ``category_input`` is the max-pooled tensor from ``torch.max(encoded_paragraph, 1)``,
            # not the per-token encoding used by the other ``weighted_sum`` calls - confirm this argument
            prev_start = util.weighted_sum(category_input,
                                           tmp_start_probability)
            tiled_prev_start = prev_start.unsqueeze(1).expand(
                batch_size, paragraph_size, modeling_dim)
            span_start_input_after = torch.cat(
                (span_start_input, tiled_prev_start), dim=2)
            encoded_start_input_after = self._dropout(
                self._span_start_encoder_after(span_start_input_after,
                                               para_lstm_mask_slice))
            span_start_input_after_cat = torch.cat(
                [encoded_paragraph, encoded_start_input_after], dim=-1)

            # predict the after location start
            span_start_logits_after = self._span_start_predictor_after(
                span_start_input_after_cat).squeeze(-1)
            # shape:  batchsize * paragraph_size
            span_start_probs_after = util.masked_softmax(
                span_start_logits_after, para_mask_slice)
            tmp_start_probability = span_start_probs_after

            # after location end prediction: takes contextual embeddings and weight sum vector as input
            # shape:  batchsize * hiddensize
            span_start_representation_after = util.weighted_sum(
                encoded_paragraph, span_start_probs_after)
            # Tensor Shape: (batch_size, passage_length, modeling_dim)
            tiled_start_representation_after = span_start_representation_after.unsqueeze(
                1).expand(batch_size, paragraph_size, modeling_dim)
            # shape: batchsize * paragraph_size * 2hiddensize
            span_end_representation_after = torch.cat(
                [encoded_paragraph, tiled_start_representation_after], dim=-1)
            # Tensor Shape: (batch_size, passage_length, encoding_dim)
            encoded_span_end_after = self._dropout(
                self._span_end_encoder_after(span_end_representation_after,
                                             para_lstm_mask_slice))
            encoded_span_end_after = torch.cat(
                [encoded_paragraph, encoded_span_end_after], dim=-1)
            # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
            span_end_logits_after = self._span_end_predictor_after(
                encoded_span_end_after).squeeze(-1)
            span_end_probs_after = util.masked_softmax(span_end_logits_after,
                                                       para_mask_slice)

            # get the best span for after location prediction
            best_span_after, best_span_after_start, best_span_after_end, best_span_after_real = \
                self._get_best_span_single_extend(span_start_logits_after, span_end_logits_after,
                                                  category_predict_logits_after, after_category_mask_slice)

            # copy current best span to the list for final evaluation
            best_span_after_list[:, index, :] = best_span_after.data.view(
                batch_size, 1, 2)
            """ Compute the Loss for this slice """
            after_category_mask = after_category_mask_slice.float().squeeze(
                -1)  # batchsize
            after_category_slice = after_category_list[:,
                                                       index, :]  # batchsize * 1
            after_loc_start_slice = after_loc_start_list[:, index, :]
            after_loc_end_slice = after_loc_end_list[:, index, :]

            # compute the loss for (after location) three category classification
            para_index_mask_slice_tiled = para_index_mask_slice.unsqueeze(
                1).expand(para_index_mask_slice.size(0), 3)
            after_category_pred = util.masked_softmax(
                category_predict_logits_after, para_index_mask_slice_tiled)
            logpy_after_category = torch.gather(after_category_pred, 1,
                                                after_category_slice).view(-1)
            loss += -(logpy_after_category * para_index_mask_slice).mean()

            # compute the loss for location start/end prediction
            after_loc_start_pred = util.masked_softmax(span_start_logits_after,
                                                       para_mask_slice)
            logpy_after_start = torch.gather(after_loc_start_pred, 1,
                                             after_loc_start_slice).view(-1)
            loss += -(logpy_after_start * after_category_mask).mean()
            after_loc_end_pred = util.masked_softmax(span_end_logits_after,
                                                     para_mask_slice)
            logpy_after_end = torch.gather(after_loc_end_pred, 1,
                                           after_loc_end_slice).view(-1)
            loss += -(logpy_after_end * after_category_mask).mean()

        # for evaluation  (combine the all annotations)
        after_loc_start_real = self._get_real_spans_extend_list(
            after_loc_start_list, after_category_list,
            after_category_mask_list)
        after_loc_end_real = self._get_real_spans_extend_list(
            after_loc_end_list, after_category_list, after_category_mask_list)

        true_span_after = torch.stack(
            [after_loc_start_real, after_loc_end_real], dim=-1)
        true_span_after = true_span_after.squeeze(2)
        best_span_after_list = Variable(best_span_after_list)

        # flatten the batch and step dimensions of the gold spans into one
        true_span_after = true_span_after.view(
            true_span_after.size(0) * true_span_after.size(1),
            true_span_after.size(2)).float()

        # one mask row per (batch, step) pair, repeated for the two span coordinates
        para_index_mask_tiled = para_index_mask.view(-1, 1)
        para_index_mask_tiled = para_index_mask_tiled.expand(
            para_index_mask_tiled.size(0), 2)

        # NOTE(review): ``after_category_mask_list_tiled`` is computed but not read again in this method
        para_index_mask_tiled2 = para_index_mask.unsqueeze(2).expand(
            para_index_mask.size(0), para_index_mask.size(1), 2)
        after_category_mask_list_tiled = after_category_mask_list.expand(
            batch_size, list_size, 2)
        after_category_mask_list_tiled = after_category_mask_list_tiled * para_index_mask_tiled2.long(
        )

        # merge all the best spans predicted for the current batch, filter out the padded instances
        merged_sys_span, merged_gold_span = self._get_merged_spans(
            true_span_before, best_span_before, true_span_after,
            best_span_after_list, para_index_mask_tiled)

        # both merged span tensors are flattened into a single row before being returned
        output_dict = {}
        output_dict["best_span"] = merged_sys_span.view(
            1,
            merged_sys_span.size(0) * merged_sys_span.size(1))
        output_dict["true_span"] = merged_gold_span.view(
            1,
            merged_gold_span.size(0) * merged_gold_span.size(1))
        output_dict["loss"] = loss
        return output_dict
예제 #3
0
    def forward(self,  # type: ignore
                task_index: torch.IntTensor,
                reverse: torch.ByteTensor,
                for_training: torch.ByteTensor,
                train_stage: torch.IntTensor,
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Compute the loss of the training stage selected by ``train_stage``.

        Stage 0 trains the sentiment classifier, stage 1 the shared domain
        discriminator and stage 2 the validity discriminator. When
        ``for_training`` is falsy all three branches run in that order and the
        returned loss is the one assigned last.

        :param task_index: index of the task this batch belongs to; used to pick
            the per-task accuracy metric and as the domain-discriminator target.
        :param reverse: passed as ``reverse=`` to the domain and validity
            discriminators.
        :param for_training: when falsy, every stage branch is executed.
        :param train_stage: ["share_senti", "share_classify",
            "share_classify_adversarial", "domain_valid", "domain_valid_adversarial"]
            (original author's list; only the values 0, 1 and 2 are handled here).
        :param tokens: text field tensors to embed and encode.
        :param label: optional sentiment labels, needed for the stage 0 loss.
        :return: ``{"loss": loss}``
        :raises ValueError: if no branch produced a loss (unhandled stage, or
            stage 0 without ``label``).
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()
        embed_tokens = self._encoder(embedded_text, mask)
        batch_size = get_batch_size(embed_tokens)
        # bs * (25*4)
        seq_vec = self._seq_vec(embed_tokens, mask)
        # TODO add linear layer

        # Helper tensors follow the device of the encoded batch instead of
        # assuming CUDA, so the model also runs on CPU or a non-default GPU.
        device = seq_vec.device

        domain_embeddings = self._domain_embeddings(torch.arange(self._de_dim, device=device))

        de_scores = F.softmax(
            self._de_attention(seq_vec, domain_embeddings.expand(batch_size, *domain_embeddings.size())), dim=1)
        # With probability 0.3 the attention scores are perturbed with noise; such
        # batches receive the validity label 0 below.
        de_valid = bool(np.random.rand() < 0.3)
        if de_valid:
            noise = 0.01 * torch.normal(mean=0.5,
                                        std=torch.empty(*de_scores.size()).fill_(1.0))
            de_scores = de_scores + noise.to(device)
        domain_embedding = torch.matmul(de_scores, domain_embeddings)
        domain_embedding = self._de_feedforward(domain_embedding)

        # Read the stage once instead of copying the tensor to the CPU per branch.
        stage = int(train_stage)
        run_all_stages = not for_training
        loss = None

        # train sentiment classify
        if stage == 0 or run_all_stages:
            de_representation = torch.tanh(torch.add(domain_embedding, seq_vec))
            sentiment_logits = self._sentiment_discriminator(de_representation)
            if label is not None:
                loss = self._loss(sentiment_logits, label)
                self.metrics["{}_stm_acc".format(TASKS_NAME[task_index.cpu()])](sentiment_logits, label)

        if stage == 1 or run_all_stages:
            s_domain_logits = self._s_domain_discriminator(seq_vec, reverse=reverse)
            task_index = task_index.expand(batch_size)
            loss = self._domain_loss(s_domain_logits, task_index)
            self.metrics["s_domain_acc"](s_domain_logits, task_index)

        if stage == 2 or run_all_stages:
            valid_logits = self._valid_discriminator(domain_embedding, reverse=reverse)
            if de_valid:
                valid_label = torch.zeros(batch_size, device=device)
            else:
                valid_label = torch.ones(batch_size, device=device)
            if self._label_smoothing is not None and self._label_smoothing > 0.0:
                loss = sequence_cross_entropy_with_logits(valid_logits,
                                                          valid_label.unsqueeze(0),
                                                          torch.tensor(1, device=device).unsqueeze(0),
                                                          average="token",
                                                          label_smoothing=self._label_smoothing)
            else:
                # NOTE(review): ``scatter_`` is given a float index of length
                # ``batch_size`` on a 2-element target - presumably a one-hot
                # encoding was intended; confirm before relying on this branch.
                loss = self._valid_loss(valid_logits,
                                        torch.zeros(2).scatter_(0, valid_label, torch.tensor(1.0)).to(device))
            self.metrics["valid_acc"](valid_logits, valid_label)

        # TODO add orthogonal loss
        if loss is None:
            # Previously this surfaced as an UnboundLocalError on ``loss``.
            raise ValueError(
                "No loss was computed for train_stage={} (label given: {})".format(stage, label is not None))
        output_dict = {"loss": loss}

        return output_dict
예제 #4
0
    def forward(
            self,  # type: ignore
            task_index: torch.IntTensor,
            reverse: torch.ByteTensor,
            epoch_trained: torch.IntTensor,
            for_training: torch.ByteTensor,
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Encode ``tokens`` and score them with the sentiment, private-domain,
        shared-domain and validity discriminators.

        :param task_index: index of the task this batch belongs to. It is expanded
            to ``batch_size`` and used as the target of both domain losses.
        :param reverse: forwarded to the encoder and, as ``reverse=``, to the
            shared-domain and validity discriminators.
        :param epoch_trained: forwarded to the encoder.
        :param for_training: forwarded to the encoder.
        :param tokens: text field tensors, forwarded to the encoder.
        :param label: optional sentiment labels. Losses and metrics are only
            computed when it is given.
        :return: a dict that always holds ``logits`` (the sentiment logits). When
            ``label`` is given it also holds ``stm_loss``, ``p_d_loss``,
            ``s_d_loss``, ``valid_loss`` and the total ``loss``. The validity loss
            is reported but not part of the total.
        """
        embeddeds = self._encoder(task_index, tokens, epoch_trained,
                                  self._valid_discriminator, reverse,
                                  for_training)
        batch_size = get_batch_size(embeddeds["embedded_text"])

        sentiment_logits = self._sentiment_discriminator(
            embeddeds["embedded_text"])

        p_domain_logits = self._p_domain_discriminator(
            embeddeds["private_embedding"])

        # TODO set reverse = true
        s_domain_logits = self._s_domain_discriminator(
            embeddeds["share_embedding"], reverse=reverse)
        # TODO set reverse = true
        # TODO use share_embedding instead of domain_embedding
        valid_logits = self._valid_discriminator(embeddeds["domain_embedding"],
                                                 reverse=reverse)

        valid_label = embeddeds['valid']

        logits = [
            sentiment_logits, p_domain_logits, s_domain_logits, valid_logits
        ]

        output_dict = {'logits': sentiment_logits}
        if label is None:
            return output_dict

        # Helper tensors follow the device of the logits instead of assuming CUDA.
        device = valid_logits.device

        loss = self._loss(sentiment_logits, label)
        # Every example in the batch shares the same task, so the single task
        # index is broadcast to serve as the domain target.
        task_index = task_index.expand(batch_size)
        targets = [label, task_index, task_index, valid_label]
        p_domain_loss = self._domain_loss(p_domain_logits, task_index)
        s_domain_loss = self._domain_loss(s_domain_logits, task_index)

        # The message is formatted up front: "{}" placeholders are not
        # interpolated by the standard ``logging`` module, and a pre-built string
        # is logged verbatim by any logger implementation.
        share_domain_std = torch.mean(
            torch.std(F.softmax(s_domain_logits, dim=-1), dim=-1))
        logger.info("Share domain logits standard variation is {}".format(
            share_domain_std))

        if self._label_smoothing is not None and self._label_smoothing > 0.0:
            valid_loss = sequence_cross_entropy_with_logits(
                valid_logits,
                valid_label.unsqueeze(0).to(device),
                torch.tensor(1, device=device).unsqueeze(0),
                average="token",
                label_smoothing=self._label_smoothing)
        else:
            # NOTE(review): ``scatter_`` into a 2-element vector with
            # ``valid_label`` as index presumably aims at a one-hot target;
            # confirm the expected shape and dtype of ``embeddeds['valid']``.
            valid_loss = self._valid_loss(
                valid_logits,
                torch.zeros(2).scatter_(0, valid_label,
                                        torch.tensor(1.0)).to(device))
        output_dict['stm_loss'] = loss
        output_dict['p_d_loss'] = p_domain_loss
        output_dict['s_d_loss'] = s_domain_loss
        output_dict['valid_loss'] = valid_loss
        # TODO add share domain logits std loss
        # ``valid_loss`` is intentionally left out of the total (the original
        # kept its ``0.005 * valid_loss`` term commented out).
        output_dict['loss'] = loss + p_domain_loss + 0.005 * s_domain_loss

        # Metrics are paired with logits/targets by position, relying on the
        # insertion order of ``self.metrics``.
        for metric, logit, target in zip(self.metrics.values(), logits,
                                         targets):
            metric(logit, target)

        return output_dict