Example #1
class BiLSTM_CRF(nn.Module):
    def __init__(self, args):
        super(BiLSTM_CRF, self).__init__()

        self.name = args.name
        self.hidden_size = args.hidden_size
        self.num_tags = args.num_tags
        self.embedding = nn.Embedding(args.embed_size, args.embed_dim)

        self.crf = ConditionalRandomField(self.num_tags, args.constraints)
        self.lstm = nn.LSTM(input_size=args.embed_dim,
                            hidden_size=args.hidden_size // 2,
                            num_layers=1,
                            bidirectional=True)
        self.linear = nn.Linear(self.hidden_size, self.num_tags)

        self.device = args.device
        self.dropout = nn.Dropout(args.dropout)

    def get_logits(self, sequences):
        batch_size = sequences.shape[0]
        sequences = sequences.transpose(0, 1)

        embedded = self.embedding(sequences)  # (sequence_len, batch_size, embedding_size)

        h0 = torch.randn(2,
                         batch_size,
                         self.hidden_size // 2,
                         device=sequences.device)
        c0 = torch.randn(2,
                         batch_size,
                         self.hidden_size // 2,
                         device=sequences.device)

        outputs, _ = self.lstm(embedded, (h0, c0))

        outputs = self.dropout(outputs)

        outputs = outputs.transpose(
            0, 1)  # (batch_size, sequence_len, hidden_size)

        logits = self.linear(outputs)

        return logits

    def forward(self, sequences: torch.Tensor, tags: torch.Tensor,
                mask) -> torch.Tensor:
        logits = self.get_logits(sequences)
        log_likelihood = self.crf(logits, tags, mask)
        loss = -log_likelihood
        return loss

    def predict(self, sequences, mask):
        logits = self.get_logits(sequences)
        best_path = self.crf.viterbi_tags(logits, mask)
        tags_pred = [tags for tags, score in best_path]
        return tags_pred
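
A minimal usage sketch for BiLSTM_CRF above (not part of the original example). It assumes PyTorch and AllenNLP's ConditionalRandomField are importable; the Namespace fields and tensor sizes below are made up for illustration.

import argparse
import torch

args = argparse.Namespace(name='bilstm_crf', hidden_size=256, num_tags=5,
                          embed_size=10000, embed_dim=100, constraints=None,
                          device='cpu', dropout=0.5)
model = BiLSTM_CRF(args)

sequences = torch.randint(0, args.embed_size, (8, 40))  # (batch_size, sequence_len) token ids
tags = torch.randint(0, args.num_tags, (8, 40))         # gold tag ids
mask = torch.ones(8, 40, dtype=torch.bool)              # True for real (non-pad) tokens

loss = model(sequences, tags, mask)     # negative CRF log-likelihood
paths = model.predict(sequences, mask)  # list of best tag-id sequences, one per batch element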
Example #2
class gru_crf(nn.Module):
    def __init__(self, num_input_features: '(int) number of input features',
                 hidden_size: '(int) number of hidden features; the outputs will also have hidden_size features',
                 num_layers: '(int) number of stacked GRU layers',
                 dropout_gru,
                 bidirectional: '(bool) if True, use bidirectional GRU',
                 tags: "(dict[int: str]) example: {0:'I', 1:'B', 2:'O', 3:'<PAD>'}",
                 dropout_FCN: '(double)',
                 drop_GRU_out):
        super().__init__()
        self.gru = nn.GRU(input_size=num_input_features, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True, dropout=dropout_gru,
                          bidirectional=bidirectional)
        #self.gru = WeightDropGRU(input_size = num_input_features, hidden_size = hidden_size, \
        #                         num_layers = num_layers, batch_first = True, dropout = dropout_gru, \
        #                         bidirectional = bidirectional, weight_dropout=drop_weight)
        all_transition=allowed_transitions('BIO', tags)
        #self.crf = CRF(num_tags=len(tags), batch_first= True)
        self.linear = nn.Linear(hidden_size*2, hidden_size)
        self.BN = nn.BatchNorm1d(num_layers)
        self.linear2 = nn.Linear(hidden_size, len(tags))
        self.BN2 = nn.BatchNorm1d(num_layers)
        self.crf = ConditionalRandomField(len(tags), all_transition)
        self.dropout = dropout_FCN
        self.drop_GRU_out = drop_GRU_out
        
    def forward(self, samples,
                target: '(torch.tensor) shape=(...............,) the target tags to be used',
                mask: 'True for non-pad elements'):
        length = samples[1]
        samples = samples[0]
        batch_size, words, _ = samples.size()
        tmp_t = time()
        #print(samples.size())
        tmp_compute = F.dropout(self.gru(samples)[0], p=self.dropout)
        #print('pass inference gru')
        tmp_compute = tmp_compute.view(batch_size, words, -1)
        #print('pass reshape gru')
#         print(f'total GRU time: {time() - tmp_t}')
        index_to_cut = max(length).item()#get_longest_seq_len(mask)
        #length = torch.mean(length.float()).item()
        ##############################################
        ###cut padding some parts out#################
        #print(tmp_compute.size())
        #tmp_compute = self.dropout(tmp_compute)
        tmp_compute = F.dropout(F.relu(self.BN(self.linear(tmp_compute))), p=self.drop_GRU_out)
        tmp_compute = F.relu(self.BN2(self.linear2(tmp_compute)))
        tmp_compute = F.dropout(tmp_compute[:, :index_to_cut,:],  p=self.dropout)
        target = target[:, :index_to_cut]
        mask = mask[:, :index_to_cut]
        #print(tmp_compute.size())
        log_likelihood = self.crf(tmp_compute, target.long(), mask)
        nll_loss = -log_likelihood  # the CRF's forward() returns the log-likelihood; negate it for a loss
#         print(f'total CRF time: {time() - tmp_t}')
        return nll_loss  # /length
    def predict(self, samples, mask):
        length = samples[1]
        samples = samples[0]
        batch_size, words, _ = samples.size()
        tmp_t = time()
        tmp_compute = self.gru(samples)[0].view(batch_size, words, -1)
#         print(f'total GRU time: {time() - tmp_t}')
        index_to_cut = max(length).item()#get_longest_seq_len(mask)
        ##############################################
        ###cut padding some parts out#################
        #print(tmp_compute.size())
        
        tmp_compute = F.relu(self.BN(self.linear(tmp_compute)))
        tmp_compute = F.relu(self.BN2(self.linear2(tmp_compute)))
        tmp_compute = tmp_compute[:, :index_to_cut,:]
        mask = mask[:, :index_to_cut]
        #print(tmp_compute.size())
        tmp_t = time()
        tmp_tags = self.crf.viterbi_tags(tmp_compute,mask)
#         print(f'total CRF prediction time: {time() - tmp_t}')
        return tmp_tags
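
gru_crf builds its transition constraints with AllenNLP's allowed_transitions helper. Below is a small sketch of how that constraint list is constructed and handed to the CRF; the tag dictionary mirrors the docstring above, and the import path may differ between AllenNLP versions.

from allennlp.modules.conditional_random_field import (ConditionalRandomField,
                                                        allowed_transitions)

tags = {0: 'I', 1: 'B', 2: 'O', 3: '<PAD>'}     # index-to-label map, as in the docstring
constraints = allowed_transitions('BIO', tags)  # allowed (from_tag_index, to_tag_index) pairs
crf = ConditionalRandomField(len(tags), constraints)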
Example #3
class SeqClassificationModel(Model):
    """
    Question answering model where answers are sentences
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        use_sep: bool = True,
        with_crf: bool = False,
        self_attn: Seq2SeqEncoder = None,
        bert_dropout: float = 0.1,
        sci_sum: bool = False,
        additional_feature_size: int = 0,
    ) -> None:
        super(SeqClassificationModel, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.vocab = vocab
        self.use_sep = use_sep
        self.with_crf = with_crf
        self.sci_sum = sci_sum
        self.self_attn = self_attn
        self.additional_feature_size = additional_feature_size

        self.dropout = torch.nn.Dropout(p=bert_dropout)

        # define loss
        if self.sci_sum:
            self.loss = torch.nn.MSELoss(
                reduction='none')  # labels are rouge scores
            self.labels_are_scores = True
            self.num_labels = 1
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                  reduction='none')
            self.labels_are_scores = False
            self.num_labels = self.vocab.get_vocab_size(namespace='labels')
            # define accuracy metrics
            self.label_accuracy = CategoricalAccuracy()
            self.all_f1_metrics = FBetaMeasure(beta=1.0, average='micro')
            self.label_f1_metrics = {}

            # define F1 metrics per label
            for label_index in range(self.num_labels):
                label_name = self.vocab.get_token_from_index(
                    namespace='labels', index=label_index)
                self.label_f1_metrics[label_name] = F1Measure(label_index)

        encoded_senetence_dim = text_field_embedder._token_embedders[
            'bert'].output_dim

        ff_in_dim = encoded_senetence_dim if self.use_sep else self_attn.get_output_dim(
        )
        ff_in_dim += self.additional_feature_size

        self.time_distributed_aggregate_feedforward = TimeDistributed(
            Linear(ff_in_dim, self.num_labels))

        if self.with_crf:
            self.crf = ConditionalRandomField(
                self.num_labels,
                constraints=None,
                include_start_end_transitions=True)

    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        TODO: add description

        Returns
        -------
        An output dictionary consisting of:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # Input: sentences
        # Output: embedded_sentences

        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = embedded_sentences.size()

        if self.use_sep:
            # The following code collects vectors of the SEP tokens from all the examples in the batch,
            # and arranges them in one list. It does the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            sentences_mask = sentences[
                'bert'] == 103  # mask for all the SEP tokens in the batch
            embedded_sentences = embedded_sentences[
                sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
            # returns num_sentences_per_batch x vector_len
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(dim=0)
            embedded_sentences = self.dropout(embedded_sentences)

            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[
                    labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # but then `num_labels` should be greater than `num_sentences`
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    labels = labels[:
                                    num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:
                                                                  num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = embedded_sentences.size()
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            #
            # # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]
            # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}")

            label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            flattened_gold = labels.contiguous().view(-1)

            if not self.with_crf:
                label_loss = self.loss(flattened_logits.squeeze(),
                                       flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                evaluation_mask = (flattened_gold != -1)
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                self.all_f1_metrics(flattened_probs,
                                    flattened_gold,
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict

    def get_metrics(self, reset: bool = False):
        metric_dict = {}

        if not self.labels_are_scores:
            type_accuracy = self.label_accuracy.get_metric(reset)
            metric_dict['acc'] = type_accuracy
            type_f1 = self.all_f1_metrics.get_metric(reset)
            metric_dict['F1'] = type_f1['fscore']

            average_F1 = 0.0
            for name, metric in self.label_f1_metrics.items():
                metric_val = metric.get_metric(reset)
                metric_dict[name + 'F'] = metric_val[2]
                average_F1 += metric_val[2]

            average_F1 /= len(self.label_f1_metrics.items())
            metric_dict['avgF'] = average_F1

        return metric_dict
Example #4
class BertMiddleModel(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        seq2seq_encoder: Seq2SeqEncoder,
        feedforward_encoder: Seq2SeqEncoder,
        dropout: float = 0.0,
        use_crf: bool = False,
        pos_weight: float = 1.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ):

        super(BertMiddleModel, self).__init__(vocab, regularizer)
        self._vocabulary = vocab
        self._text_field_embedder = text_field_embedder
        self._seq2seq_encoder = seq2seq_encoder
        self._dropout = torch.nn.Dropout(p=dropout)

        self._feedforward_encoder = feedforward_encoder
        self._classifier_input_dim = feedforward_encoder.get_output_dim()

        self._classification_layer = torch.nn.Linear(
            self._classifier_input_dim, 2)

        self._use_crf = use_crf

        self._pos_weight = torch.Tensor([1 / (1 - pos_weight), 1 / pos_weight])
        self._pos_weight = torch.nn.Parameter(self._pos_weight /
                                              self._pos_weight.min())
        self._pos_weight.requires_grad = False

        if use_crf:
            self._crf = ConditionalRandomField(num_tags=2)

        self._token_prf = F1Measure(1)

        initializer(self)

    def forward(self,
                document,
                query=None,
                rationale=None,
                metadata=None,
                label=None) -> Dict[str, Any]:
        embedded_text = self._text_field_embedder(document)
        mask = util.get_text_field_mask(document).float()

        embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)
        embedded_text = self._dropout(self._feedforward_encoder(embedded_text))

        logits = self._classification_layer(embedded_text)

        if self._use_crf:
            best_paths = self._crf.viterbi_tags(logits, mask=document["mask"])
            best_paths = [b[0] for b in best_paths]
            best_paths = [
                x + [0] * (logits.shape[1] - len(x)) for x in best_paths
            ]
            best_paths = torch.Tensor(best_paths).to(
                logits.device) * document["mask"]
        else:
            best_paths = (logits[:, :, 1] > 0.5).long() * document["mask"]

        output_dict = {}

        output_dict["predicted_rationales"] = best_paths
        output_dict["mask"] = document["mask"]
        output_dict["metadata"] = metadata

        if rationale is not None:
            if self._use_crf:
                output_dict["loss"] = -self._crf(logits, rationale,
                                                 document["mask"])
            else:
                output_dict["loss"] = ((F.cross_entropy(
                    logits.view(-1, logits.shape[-1]),
                    rationale.view(-1),
                    reduction="none",
                    weight=self._pos_weight,
                ) * document["mask"].view(-1)).sum(-1).mean())

            best_paths = best_paths.unsqueeze(-1)
            best_paths = torch.cat([1 - best_paths, best_paths], dim=-1)
            self._token_prf(best_paths, rationale, document["mask"])
        return output_dict

    def extract_rationale(self, output_dict):
        rationales = []
        sentences = [x["tokens"] for x in output_dict["metadata"]]
        predicted_rationales = output_dict["predicted_rationales"].cpu(
        ).data.numpy()
        for path, words in zip(predicted_rationales, sentences):
            path = list(path)[:len(words)]
            words = [x.text for x in words]
            starts, ends = [], []
            # pad both ends with 0 so spans touching either boundary are detected
            padded = [0] + path + [0]
            for i in range(1, len(padded)):
                if padded[i - 1:i + 1] == [0, 1]:
                    starts.append(i - 1)
                if padded[i - 1:i + 1] == [1, 0]:
                    ends.append(i - 1)

            assert len(starts) == len(ends)
            spans = list(zip(starts, ends))

            rationales.append({
                "document":
                " ".join([w for i, w in zip(path, words) if i == 1]),
                "spans": [{
                    "span": (s, e),
                    "value": 1
                } for s, e in spans],
                "metadata":
                None,
            })

        return rationales

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = self._token_prf.get_metric(reset)
        return dict(zip(["p", "r", "f1"], metrics))

    def decode(self, output_dict):
        rationales = self.extract_rationale(output_dict)
        new_output_dict = {}

        new_output_dict['rationale'] = rationales
        new_output_dict['document'] = [r['document'] for r in rationales]

        if 'query' in output_dict['metadata'][0]:
            output_dict['query'] = [
                m['query'] for m in output_dict['metadata']
            ]

        for m in output_dict["metadata"]:
            if 'convert_tokens_to_instance' in m:
                del m["convert_tokens_to_instance"]

        new_output_dict['label'] = [
            m['label'] for m in output_dict['metadata']
        ]
        new_output_dict['metadata'] = output_dict['metadata']

        return new_output_dict
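
Every example on this page drives AllenNLP's ConditionalRandomField the same way: the negated forward pass is the training loss and viterbi_tags decodes at prediction time. A stripped-down sketch with made-up shapes (recent AllenNLP releases expect a boolean mask):

import torch
from allennlp.modules.conditional_random_field import ConditionalRandomField

crf = ConditionalRandomField(num_tags=2)        # e.g. rationale / non-rationale as in BertMiddleModel
logits = torch.randn(4, 20, 2)                  # (batch_size, seq_len, num_tags)
gold = torch.randint(0, 2, (4, 20))             # gold tag ids
crf_mask = torch.ones(4, 20, dtype=torch.bool)  # True for non-pad positions

loss = -crf(logits, gold, crf_mask)              # forward() returns the log-likelihood
best_paths = crf.viterbi_tags(logits, crf_mask)  # [(tag_id_list, viterbi_score), ...]
predicted = [tags for tags, score in best_paths]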
Example #5
class BaseClfHead(nn.Module):
    """ All-in-one Classifier Head for the Basic Language Model """
    def __init__(self,
                 config,
                 lm_model,
                 lm_config,
                 num_lbs=1,
                 mlt_trnsfmr=False,
                 task_params={},
                 binlb={},
                 binlbr={},
                 **kwargs):
        from . import reduction as R
        super(BaseClfHead, self).__init__()
        self.lm_model = lm_model
        self.lm_config = lm_config
        self.input_keys = config.input_keys
        self.maxlen = config.maxlen
        self.lm_loss = kwargs.setdefault(
            'lm_loss', config.lm_loss if hasattr(config, 'lm_loss') else True)
        self.lm_head = self.__lm_head__()
        self.num_lbs = num_lbs
        pdrop = kwargs.setdefault(
            'pdrop', config.pdrop if hasattr(config, 'pdrop') else 0.2)
        self.sample_weights = kwargs.setdefault(
            'sample_weights',
            config.sample_weights if hasattr(config, 'sample_weights') else False)
        self.mlt_trnsfmr = mlt_trnsfmr  # accept multiple streams of inputs, each of which will be input into the transformer
        self.task_type = kwargs.setdefault('task_type', config.task_type)
        self.task_params = task_params

        self.do_norm = kwargs.setdefault(
            'do_norm', config.do_norm if hasattr(config, 'do_norm') else False)
        self.do_extlin = kwargs.setdefault(
            'do_extlin',
            config.do_extlin if hasattr(config, 'do_extlin') else True)
        self.do_lastdrop = kwargs.setdefault(
            'do_lastdrop',
            config.do_lastdrop if hasattr(config, 'do_lastdrop') else True)
        self.dropout = nn.Dropout2d(
            pdrop) if self.task_type == 'nmt' else nn.Dropout(pdrop)
        self.last_dropout = nn.Dropout(pdrop) if self.do_lastdrop else None
        do_crf = kwargs.setdefault(
            'do_crf', config.do_crf if hasattr(config, 'do_crf') else False)
        self.crf = ConditionalRandomField(num_lbs) if do_crf else None
        constraints = kwargs.setdefault(
            'cnstrnts',
            config.cnstrnts.split(',')
            if hasattr(config, 'cnstrnts') and config.cnstrnts else [])
        self.constraints = [
            cnstrnt_cls(**cnstrnt_params)
            for cnstrnt_cls, cnstrnt_params in constraints
        ]
        do_thrshld = kwargs.setdefault(
            'do_thrshld',
            config.do_thrshld if hasattr(config, 'do_thrshld') else False)
        self.thrshlder = R.ThresholdEstimator(
            last_hdim=kwargs['last_hdim']
        ) if do_thrshld and 'last_hdim' in kwargs else None
        self.thrshld = kwargs.setdefault('thrshld', 0.5)

        # Customized function calls
        self.lm_logit = self._mlt_lm_logit if self.mlt_trnsfmr else self._lm_logit
        self.clf_h = self._clf_h
        self.dim_mulriple = 2 if self.mlt_trnsfmr and self.task_type in [
            'entlmnt', 'sentsim'
        ] and self.task_params.setdefault(
            'sentsim_func', None) is not None and self.task_params[
                'sentsim_func'] == 'concat' else 1  # two or one sentence
        if self.dim_mulriple > 1 and self.task_params.setdefault(
                'concat_strategy', 'normal') == 'diff':
            self.dim_mulriple = 4

        self.kwprop = {}
        self.binlb = binlb
        self.global_binlb = copy.deepcopy(binlb)
        self.binlbr = binlbr
        self.global_binlbr = copy.deepcopy(binlbr)
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.mode = 'clf'
        self.debug = config.verbose if hasattr(config, 'verbose') else False

    def __init_linear__(self):
        raise NotImplementedError

    def __lm_head__(self):
        raise NotImplementedError

    def __default_pooler__(self):
        from . import reduction as R
        return R.MaskedReduction(reduction=None, dim=1)

    def forward(self,
                input_ids,
                *extra_inputs,
                labels=None,
                all_hidden_states=None,
                weights=None,
                embedding_mode=False):
        use_gpu = input_ids[0].is_cuda if type(
            input_ids) is list else input_ids.is_cuda
        if self.sample_weights and len(extra_inputs) > 0:
            sample_weights = extra_inputs[-1]
            extra_inputs = extra_inputs[:-1]
        else:
            sample_weights = None
        extra_inputs_dict = dict(
            zip([x for x in self.input_keys if x != 'input_ids'],
                extra_inputs))
        pool_idx = extra_inputs_dict['attention_mask'].sum(1)
        mask = extra_inputs_dict['attention_mask']
        if self.debug:
            logging.debug(('size of input_ids', [x.size() for x in input_ids]
                           if type(input_ids) is list else input_ids.size()))
            logging.debug(
                ('input_ids',
                 [[','.join(map(str, s.tolist())) for s in x]
                  for x in input_ids[:5]] if type(input_ids) is list else
                 [','.join(map(str, s.tolist())) for s in input_ids[:5]]))
        # Go through the language model
        output_fields = set(['last_hidden_state', 'hidden_states'])
        if self.mlt_trnsfmr and self.task_type in ['entlmnt', 'sentsim']:
            trnsfm_output = [
                self.transformer(input_ids[x], **extra_inputs_dict)
                for x in [0, 1]
            ]
            hidden_states, all_hidden_states = zip(*[[
                trnsfm_output[x][k] if k in trnsfm_output[x] else None
                for k in ['last_hidden_state', 'hidden_states']
            ] for x in [0, 1]])
            hidden_states, all_hidden_states = list(hidden_states), list(
                all_hidden_states)
            extra_outputs = [
                dict([(k, v) for k, v in trnsfm_output[x].items()
                      if k not in output_fields]) for x in [0, 1]
            ]
        else:
            trnsfm_output = self.transformer(input_ids, **extra_inputs_dict)
            hidden_states, all_hidden_states = (
                trnsfm_output[k] if k in trnsfm_output else None
                for k in ['last_hidden_state', 'hidden_states'])
            extra_outputs = dict([(k, v) for k, v in trnsfm_output.items()
                                  if k not in output_fields])
        if self.debug: logging.debug(('after transformer', trnsfm_output[:5]))

        # Calculate language model loss
        if (self.lm_loss):
            lm_logits, lm_target = self.lm_logit(
                input_ids,
                extra_inputs_dict,
                hidden_states,
                all_hidden_states=all_hidden_states,
                extra_outputs=extra_outputs)
            lm_loss_func = nn.CrossEntropyLoss(ignore_index=-1,
                                               reduction='none')
            lm_loss = lm_loss_func(
                lm_logits.contiguous().view(-1, lm_logits.size(-1)),
                lm_target.contiguous().view(-1)).view(input_ids.size(0), -1)
            if sample_weights is not None: lm_loss *= sample_weights
        else:
            lm_loss = None

        # Pooling
        if self.debug:
            logging.debug(
                ('hdstat: ', [x.size() for x in hidden_states]
                 if type(hidden_states) is list else hidden_states.size()))
        clf_h, mask = self.clf_h(hidden_states,
                                 mask,
                                 all_hidden_states=all_hidden_states,
                                 extra_outputs=extra_outputs)
        if self.debug:
            logging.debug(
                ('after clf_h', [x.size()
                                 for x in clf_h] if type(clf_h) is list
                 or type(clf_h) is tuple else clf_h.size()))
        clf_h = self.pool(input_ids,
                          extra_inputs_dict,
                          mask,
                          clf_h,
                          extra_outputs=extra_outputs)
        if self.debug:
            logging.debug(
                ('after pool',
                 [x.size()
                  for x in clf_h] if type(clf_h) is list else clf_h.size()))

        # Other layers
        if self.mlt_trnsfmr and self.task_type in [
                'entlmnt', 'sentsim'
        ] and (self.task_params.setdefault('sentsim_func', None) is not None
               ):  # default sentsim mode of gpt* is mlt_trnsfmr+_mlt_clf_h
            if self.do_norm: clf_h = [self.norm(clf_h[x]) for x in [0, 1]]
            if self.do_drop: clf_h = [self.dropout(clf_h[x]) for x in [0, 1]]
            if self.do_extlin and hasattr(self, 'extlinear'):
                clf_h = [self.extlinear(clf_h[x]) for x in [0, 1]]
            if embedding_mode: return clf_h
            if self.task_params.setdefault('sentsim_func', None) == 'concat':
                if self.task_params.setdefault('concat_strategy',
                                               'normal') == 'reverse':
                    clf_h = (torch.cat(clf_h, dim=-1) +
                             torch.cat(clf_h[::-1], dim=-1))
                elif self.task_params['concat_strategy'] == 'diff':
                    clf_h = torch.cat(
                        clf_h +
                        [torch.abs(clf_h[0] - clf_h[1]), clf_h[0] * clf_h[1]],
                        dim=-1)
                else:
                    clf_h = torch.cat(clf_h, dim=-1)
                clf_logits = self.linear(clf_h) if self.linear else clf_h
            elif self.task_type == 'sentsim':
                clf_logits = clf_h = F.pairwise_distance(
                    self.linear(clf_h[0]), self.linear(
                        clf_h[1]), 2, eps=1e-12) if self.task_params[
                            'sentsim_func'] == 'dist' else F.cosine_similarity(
                                self.linear(clf_h[0]),
                                self.linear(clf_h[1]),
                                dim=1,
                                eps=1e-12)
        else:
            if self.do_norm: clf_h = self.norm(clf_h)
            if self.debug: logging.debug(('before dropout:', clf_h.size()))
            if self.do_drop: clf_h = self.dropout(clf_h)
            if self.do_extlin and hasattr(self, 'extlinear'):
                clf_h = self.extlinear(clf_h)
            if embedding_mode: return clf_h
            if self.debug: logging.debug(('after dropout:', clf_h.size()))
            if self.debug: logging.debug(('linear', self.linear))
            clf_logits = self.linear(
                clf_h.view(-1, self.n_embd) if self.task_type ==
                'nmt' else clf_h)
        if self.debug: logging.debug(('after linear:', clf_logits.size()))
        if self.thrshlder: self.thrshld = self.thrshlder(clf_h)
        if self.do_lastdrop: clf_logits = self.last_dropout(clf_logits)
        if self.debug: logging.debug(('after lastdrop:', clf_logits[:5]))

        if (labels is None):
            if self.crf:
                tag_seq, score = zip(*self.crf.viterbi_tags(
                    clf_logits.view(input_ids.size()[0], -1, self.num_lbs),
                    torch.ones_like(input_ids)))
                tag_seq = torch.tensor(tag_seq).to(
                    'cuda') if use_gpu else torch.tensor(tag_seq)
                if self.debug:
                    logging.debug((tag_seq.min(), tag_seq.max(), score))
                clf_logits = torch.zeros(
                    (*tag_seq.size(),
                     self.num_lbs)).to('cuda') if use_gpu else torch.zeros(
                         (*tag_seq.size(), self.num_lbs))
                clf_logits = clf_logits.scatter(-1, tag_seq.unsqueeze(-1), 1)
                return clf_logits, extra_outputs
            for cnstrnt in self.constraints:
                clf_logits = cnstrnt(clf_logits)
            if (self.mlt_trnsfmr and self.task_type in ['entlmnt', 'sentsim']
                    and self.task_params.setdefault('sentsim_func',
                                                    None) is not None
                    and self.task_params['sentsim_func'] != 'concat'
                    and self.task_params['sentsim_func'] !=
                    self.task_params.setdefault('ymode', 'sim')):
                return 1 - clf_logits.view(-1, self.num_lbs)
            return clf_logits.view(-1, self.num_lbs), extra_outputs
        if self.debug:
            logging.debug(
                ('label max: ', labels.max(), 'label size: ', labels.size()))
        if self.crf:
            clf_loss = -self.crf(
                clf_logits.view(input_ids.size()[0], -1, self.num_lbs),
                pool_idx)
            if sample_weights is not None: clf_loss *= sample_weights
            return clf_loss, lm_loss, extra_outputs
        else:
            for cnstrnt in self.constraints:
                clf_logits = cnstrnt(clf_logits)
        if self.task_type == 'mltc-clf' or (
                self.task_type == 'entlmnt'
                and self.num_lbs > 1) or self.task_type == 'nmt':
            loss_func = nn.CrossEntropyLoss(weight=weights, reduction='none')
            clf_loss = loss_func(clf_logits.view(-1, self.num_lbs),
                                 labels.view(-1))
        elif self.task_type == 'mltl-clf' or (self.task_type == 'entlmnt'
                                              and self.num_lbs == 1):
            loss_func = nn.BCEWithLogitsLoss(
                pos_weight=10 * weights if weights is not None else None,
                reduction='none')
            clf_loss = loss_func(clf_logits.view(-1, self.num_lbs),
                                 labels.view(-1, self.num_lbs).float())
        elif self.task_type == 'sentsim':
            from util import config as C
            if self.debug: logging.debug(('clf logits: ', clf_logits.size()))
            loss_cls = C.RGRSN_LOSS_MAP[self.task_params.setdefault(
                'loss', 'contrastive'
                if self.task_params.setdefault('sentsim_func', None)
                and self.task_params['sentsim_func'] != 'concat' else 'mse')]
            loss_func = loss_cls(
                reduction='none',
                x_mode=C.SIM_FUNC_MAP.setdefault(
                    self.task_params['sentsim_func'], 'dist'),
                y_mode=self.task_params.setdefault('ymode', 'sim')
            ) if self.task_params.setdefault(
                'sentsim_func', None
            ) and self.task_params['sentsim_func'] != 'concat' else (
                loss_cls(reduction='none',
                         x_mode='sim',
                         y_mode=self.task_params.setdefault('ymode', 'sim')) if
                self.task_params['sentsim_func'] == 'concat' else nn.MSELoss(
                    reduction='none'))
            clf_loss = loss_func(clf_logits.view(-1), labels.view(-1))
        if self.thrshlder:
            num_lbs = labels.view(-1, self.num_lbs).sum(1)
            clf_loss = 0.8 * clf_loss + 0.2 * F.mse_loss(
                self.thrshld,
                torch.sigmoid(
                    torch.topk(clf_logits, k=num_lbs.max(), dim=1,
                               sorted=True)[0][:, num_lbs - 1]),
                reduction='mean')
        if sample_weights is not None: clf_loss *= sample_weights
        return clf_loss, lm_loss, extra_outputs

    def pool(self, input_ids, extra_inputs, mask, clf_h, extra_outputs={}):
        if self.task_type == 'nmt':
            if (hasattr(self, 'layer_pooler')):
                clf_h = self.layer_pooler(clf_h)
            else:
                clf_h = clf_h
        else:
            if not hasattr(self, 'pooler'):
                setattr(self, 'pooler', self.__default_pooler__())
            if self.task_type in ['entlmnt', 'sentsim'] and self.mlt_trnsfmr:
                if (hasattr(self, 'layer_pooler')):
                    lyr_h = [[self.pooler(h, mask[x]) for h in clf_h[x]]
                             for x in [0, 1]]
                    clf_h = [self.layer_pooler(lyr_h[x]) for x in [0, 1]]
                else:
                    clf_h = [self.pooler(clf_h[x], mask[x]) for x in [0, 1]]
            else:
                if (hasattr(self, 'layer_pooler')):
                    lyr_h = [self.pooler(h, mask) for h in clf_h]
                    clf_h = self.layer_pooler(lyr_h)
                else:
                    clf_h = self.pooler(clf_h, mask)
        return clf_h

    def _clf_h(self,
               hidden_states,
               mask,
               all_hidden_states=None,
               extra_outputs={}):
        return (hidden_states, torch.stack(mask).max(0)[0]
                ) if type(hidden_states) is list else (hidden_states, mask)

    def _mlt_clf_h(self,
                   hidden_states,
                   mask,
                   all_hidden_states=None,
                   extra_outputs={}):
        return torch.stack(hidden_states).sum(0), torch.stack(mask).max(0)[0]

    def transformer(self, input_ids, **extra_inputs):
        return self.lm_model(input_ids=input_ids,
                             **extra_inputs,
                             return_dict=True)

    def _lm_logit(self,
                  input_ids,
                  extra_inputs,
                  hidden_states,
                  all_hidden_states=None,
                  extra_outputs={}):
        lm_h = hidden_states[:, :-1]
        return self.lm_head(lm_h), input_ids[:, 1:]

    def _mlt_lm_logit(self,
                      input_ids,
                      hidden_states,
                      extra_inputs,
                      all_hidden_states=None,
                      extra_outputs={}):
        lm_h = hidden_states[:, :, :-1].contiguous().view(-1, self.n_embd)
        lm_target = input_ids[:, :, 1:].contiguous().view(-1)
        return self.lm_model.lm_head(lm_h), lm_target.view(-1)

    def freeze_lm(self):
        if not hasattr(self, 'lm_model') or self.lm_model is None: return
        for param in self.lm_model.parameters():
            param.requires_grad = False

    def unfreeze_lm(self):
        if not hasattr(self, 'lm_model') or self.lm_model is None: return
        for param in self.lm_model.parameters():
            param.requires_grad = True

    def to(self, *args, **kwargs):
        super(BaseClfHead, self).to(*args, **kwargs)
        self.constraints = [
            cnstrnt.to(*args, **kwargs) for cnstrnt in self.constraints
        ]
        if hasattr(self, 'linears'):
            self.linears = [lnr.to(*args, **kwargs) for lnr in self.linears]
        return self

    def add_linear(self, num_lbs, idx=0):
        self.num_lbs = num_lbs
        self._total_num_lbs = num_lbs if idx == 0 else self._total_num_lbs + num_lbs
        self.linear = self.__init_linear__()
        if not hasattr(self, 'linears'): self.linears = []
        self.linears.append(self.linear)

    def _update_global_binlb(self, binlb):
        if not hasattr(self, 'global_binlb'):
            setattr(self, 'global_binlb', copy.deepcopy(binlb))
        if not hasattr(self, 'global_binlbr'):
            setattr(self, 'global_binlbr',
                    dict([(v, k) for k, v in binlb.items()]))
        new_lbs = [lb for lb in binlb.keys() if lb not in self.global_binlb]
        self.global_binlb.update(
            dict([(k, i) for i, k in zip(
                range(len(self.global_binlb),
                      len(self.global_binlb) + len(new_lbs)), new_lbs)]))
        self.global_binlbr = dict([(v, k)
                                   for k, v in self.global_binlb.items()])

    def reset_global_binlb(self):
        delattr(self, 'global_binlb')
        delattr(self, 'global_binlbr')

    def get_linear(self, binlb, idx=0):
        self.num_lbs = len(binlb)
        self.binlb = binlb
        self.binlbr = dict([(v, k) for k, v in self.binlb.items()])
        self._update_global_binlb(binlb)
        self._total_num_lbs = len(self.global_binlb)
        if not hasattr(self, 'linears'): self.linears = []
        if len(self.linears) <= idx:
            self.linear = self.__init_linear__()
            self.linears.append(self.linear)
            return self.linears[-1]
        else:
            self.linear = self.linears[idx]
            return self.linears[idx]

    def to_siamese(self, from_scratch=False):
        if not hasattr(self, 'clf_task_type') and self.task_type != 'sentsim':
            self.clf_task_type = self.task_type
        self.task_type = 'sentsim'
        if not hasattr(self, 'clf_num_lbs') and self.task_type != 'sentsim':
            self.clf_num_lbs = self.num_lbs
        self.num_lbs = 1
        self.mlt_trnsfmr = True if isinstance(self, GPTClfHead) or (
            isinstance(self, BERTClfHead) and self.task_params.setdefault(
                'sentsim_func', None) is not None) else False
        self.dim_mulriple = 2 if self.task_params.setdefault(
            'sentsim_func', None) == 'concat' else 1
        self.clf_linear = self.linear
        self.linear = self.siamese_linear if hasattr(
            self,
            'siamese_linear') and not from_scratch else self.__init_linear__()
        self.mode = 'siamese'

    def to_clf(self, from_scratch=False):
        self.task_type = self.clf_task_type
        self.num_lbs = self.clf_num_lbs
        if self.mode == 'siamese':
            self.dim_mulriple = 1
            self.siamese_linear = self.linear
        else:
            self.prv_linear = self.linear
        self.linear = self.clf_linear if hasattr(
            self,
            'clf_linear') and not from_scratch else self.__init_linear__()
        self.mode = 'clf'

    def update_params(self, task_params={}, **kwargs):
        self.task_params.update(task_params)
        for k, v in kwargs.items():
            if hasattr(self, k) and type(v) == type(getattr(self, k)):
                if type(v) is dict:
                    getattr(self, k).update(v)
                else:
                    setattr(self, k, v)
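
When no labels are given, BaseClfHead.forward turns the Viterbi tag sequences back into one-hot pseudo-logits with scatter. A tiny illustration of that trick with made-up tensors (not part of the original class):

import torch

tag_seq = torch.tensor([[1, 0, 2], [2, 2, 0]])         # (batch_size, seq_len) predicted tag ids
num_lbs = 3
pseudo_logits = torch.zeros(*tag_seq.size(), num_lbs)  # (batch_size, seq_len, num_lbs)
pseudo_logits = pseudo_logits.scatter(-1, tag_seq.unsqueeze(-1), 1)  # 1.0 at each predicted tag index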
Example #6
class SeqClassificationModel(Model):
    """
    Question answering model where answers are sentences
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        use_sep: bool = True,
        with_crf: bool = False,
        self_attn: Seq2SeqEncoder = None,
        bert_dropout: float = 0.1,
        sci_sum: bool = False,
        additional_feature_size: int = 0,
    ) -> None:
        super(SeqClassificationModel, self).__init__(vocab)

        self.track_embedding_list = []
        self.track_embedding = {}
        self.text_field_embedder = text_field_embedder
        self.vocab = vocab
        self.use_sep = use_sep
        self.with_crf = with_crf
        self.sci_sum = sci_sum
        self.self_attn = self_attn
        self.additional_feature_size = additional_feature_size

        self.dropout = torch.nn.Dropout(p=bert_dropout)

        # define loss
        if self.sci_sum:
            self.loss = torch.nn.MSELoss(
                reduction='none')  # labels are rouge scores
            self.labels_are_scores = True
            self.num_labels = 1
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                  reduction='none')
            self.labels_are_scores = False
            self.num_labels = self.vocab.get_vocab_size(namespace='labels')
            # define accuracy metrics
            self.label_accuracy = CategoricalAccuracy()
            self.label_f1_metrics = {}

            # define F1 metrics per label
            for label_index in range(self.num_labels):
                label_name = self.vocab.get_token_from_index(
                    namespace='labels', index=label_index)
                self.label_f1_metrics[label_name] = F1Measure(label_index)

        encoded_senetence_dim = text_field_embedder._token_embedders[
            'bert'].output_dim

        ff_in_dim = encoded_senetence_dim if self.use_sep else self_attn.get_output_dim(
        )
        ff_in_dim += self.additional_feature_size

        self.time_distributed_aggregate_feedforward = TimeDistributed(
            Linear(ff_in_dim, self.num_labels))

        if self.with_crf:
            self.crf = ConditionalRandomField(
                self.num_labels,
                constraints=None,
                include_start_end_transitions=True)
        self.track_embedding["init_info"] = {
            "ff_in_dim": ff_in_dim,
            "encoded_sentence_dim": encoded_senetence_dim,
            "sci_sum": self.sci_sum,
            "use_sep": self.use_sep,
            "with_crf": self.with_crf,
            "additional_feature_size": self.additional_feature_size
        }
        self.t_board_writer = SummaryWriter()
        self.t_board_writer.add_graph(self)

    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        TODO: add description

        Returns
        -------
        An output dictionary consisting of:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # Input: sentences
        # Output: embedded_sentences
        print(sentences)
        sentences_conv = {}
        for key, val in sentences.items():  # convert the input tensors so they can be JSON-serialised
            sentences_conv[key] = val.cpu().data.numpy().tolist()
        self.track_embedding["Transformation_0"] = {
            "sentences": sentences_conv
        }
        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        self.track_embedding["Transformation_1"] = {
            "size": list(embedded_sentences.size()),
            "dim": embedded_sentences.dim()
        }

        # Kacper: Basically a padding mask for bert
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = list(embedded_sentences.size())

        if self.use_sep:
            # The following code collects vectors of the SEP tokens from all the examples in the batch,
            # and arranges them in one list. It does the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            # Kacper: This is an important step where we get SEP tokens to later do sentence classification
            # Kacper: We take a location of SEP tokens from the sentences to get a mask
            sentences_mask = sentences[
                'bert'] == 103  # mask for all the SEP tokens in the batch
            # Kacper: We use this mask to get the respective embeddings from the output layer of bert
            embedded_sentences = embedded_sentences[
                sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
            # returns num_sentences_per_batch x vector_len
            self.track_embedding["Transformation_2"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: I don't get why it became 2 instead of 4. What is the difference between size() and dim()?
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # Kacper: comment below is vague
            # Kacper: I think we batch in one array because we just need to compute a mean loss from all of them
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(
                dim=0)  # Kacper: We batch all sentences in one array
            self.track_embedding["Transformation_3"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: Dropout layer is between filtered embeddings and linear layer
            embedded_sentences = self.dropout(embedded_sentences)
            self.track_embedding["Transformation_4"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: we provide the labels for training (for each sentence)
            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[
                    labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                # Kacper: this might be useful to consider in my code as well
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # but then `num_labels` should be greater than `num_sentences`
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    labels = labels[:
                                    num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:
                                                                  num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            # Kacper: this shouldn't be the case for our project
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = list(embedded_sentences.size())
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        # Kacper: we unwrap the time dimension of a tensor into the 1st dimension (batch),
        # Kacper: apply a linear layer and wrap the time dimension back
        # Kacper: I would suspect it is happening only for embeddings related to the [SEP] tokens
        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels
        self.track_embedding["logits"] = {
            "size": list(label_logits.size()),
            "dim": label_logits.dim()
        }
        #print(self.track_embedding)
        self.track_embedding_list.append(deepcopy(self.track_embedding))
        with open(path_json, 'w') as json_out:
            json.dump(self.track_embedding_list, json_out)

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            #
            # # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]
            # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}")

            label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            # Kacper: reshape the logits with view() to (batch_size * num_sentences, num_labels)
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            # Make the labels contiguous in memory and reshape them into one dimension
            flattened_gold = labels.contiguous().view(-1)  # Kacper: True labels

            if not self.with_crf:
                # Kacper: We are only interested in this part of the code since we don't use crf
                # Kacper: Get the loss (MSE if sci_sum is True, otherwise cross-entropy)
                label_loss = self.loss(flattened_logits.squeeze(),
                                       flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()  # Kacper: Get a mean loss
                # Kacper: Get probabilities from the logits
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                # Kacper: We are not interested in this if statement branch (for our project)
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                # Kacper: this will be the case for us as well because labels are numerical for Pubmed data
                evaluation_mask = (flattened_gold != -1)
                # Kacper: CategoricalAccuracy is computed in this case
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict

    def get_metrics(self, reset: bool = False):
        # Kacper: this function has to be implemented due to API requirements for AllenNLP
        # Kacper: so it can be run automatically with a config file
        metric_dict = {}

        if not self.labels_are_scores:
            type_accuracy = self.label_accuracy.get_metric(reset)
            metric_dict['acc'] = type_accuracy

            average_F1 = 0.0
            for name, metric in self.label_f1_metrics.items():
                metric_val = metric.get_metric(reset)
                metric_dict[name + 'F'] = metric_val[2]
                average_F1 += metric_val[2]

            average_F1 /= len(self.label_f1_metrics.items())
            metric_dict['avgF'] = average_F1

        return metric_dict
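
The CRF branch above reuses CategoricalAccuracy and the per-label F1 metrics by expanding each Viterbi path into a one-hot tensor. Below is a minimal, self-contained sketch of that conversion in plain PyTorch; the helper name and the shapes are made up for illustration.

import torch

def paths_to_one_hot(predicted_paths, num_sentences, num_labels):
    # predicted_paths: List[List[int]], one decoded tag sequence per instance.
    # Positions beyond the end of a path stay all-zero; the evaluation mask is
    # expected to filter those out.
    probs = torch.zeros(len(predicted_paths), num_sentences, num_labels)
    for i, path in enumerate(predicted_paths):
        for j, label_id in enumerate(path):
            probs[i, j, label_id] = 1.0
    return probs.view(len(predicted_paths) * num_sentences, num_labels)

flattened_probs = paths_to_one_hot([[0, 2, 1], [1, 1, 0]], num_sentences=3, num_labels=4)
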
Example #7
0
class DualCrossSharedRNN(nn.Module):
    def __init__(self,
                 general_embeddings,
                 domain_embeddings,
                 input_size,
                 hidden_size,
                 aspect_tag_classes,
                 polarity_tag_classes,
                 k,
                 dropout=0.5):
        super(DualCrossSharedRNN, self).__init__()
        self.general_embedding = nn.Embedding(
            num_embeddings=general_embeddings.size(0),
            embedding_dim=general_embeddings.size(1),
            padding_idx=0).from_pretrained(general_embeddings)
        self.domain_embedding = nn.Embedding(
            num_embeddings=domain_embeddings.size(0),
            embedding_dim=domain_embeddings.size(1),
            padding_idx=0).from_pretrained(domain_embeddings)
        self.general_embedding.weight.requires_grad = False
        self.domain_embedding.weight.requires_grad = False
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.aspect_rnn1 = ReGU(input_size,
                                hidden_size,
                                num_layers=1,
                                bidirectional=True)
        self.polarity_rnn1 = ReGU(input_size,
                                  hidden_size,
                                  num_layers=1,
                                  bidirectional=True)
        self.csu = Cross_Shared_Unit(k, 2 * hidden_size)
        self.aspect_rnn2 = ReGU(2 * hidden_size,
                                hidden_size,
                                num_layers=1,
                                bidirectional=True)
        self.polarity_rnn2 = ReGU(2 * hidden_size,
                                  hidden_size,
                                  num_layers=1,
                                  bidirectional=True)
        self.aspect_hidden2tag = nn.Linear(2 * hidden_size, aspect_tag_classes)
        self.polarity_hidden2tag = nn.Linear(2 * hidden_size,
                                             polarity_tag_classes)
        self.aspect_crf = ConditionalRandomField(aspect_tag_classes)
        self.polarity_crf = ConditionalRandomField(polarity_tag_classes)
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self,
                features,
                aspect_tags,
                polarity_tags,
                mask,
                testing=False,
                crf=True):
        batch = features.size(0)
        general_features = self.general_embedding(features)
        domain_features = self.domain_embedding(features)
        features = torch.cat((general_features, domain_features), dim=2)
        states = torch.zeros(1, 2, batch, self.hidden_size).to(features.device)
        features = self.dropout_layer(features)
        aspect_hidden, _ = self.aspect_rnn1(features, states)
        polarity_hidden, _ = self.polarity_rnn1(features, states)
        aspect_hidden, polarity_hidden = self.csu(aspect_hidden,
                                                  polarity_hidden,
                                                  max_pooling=False)
        aspect_hidden, _ = self.aspect_rnn2(aspect_hidden, states)
        polarity_hidden, _ = self.polarity_rnn2(polarity_hidden, states)
        aspect_logit = self.aspect_hidden2tag(aspect_hidden)
        polarity_logit = self.polarity_hidden2tag(polarity_hidden)
        if crf:
            if not testing:
                aspect_score = -self.aspect_crf(aspect_logit, aspect_tags,
                                                mask)
                polarity_score = -self.polarity_crf(polarity_logit,
                                                    polarity_tags, mask)
                return aspect_score + polarity_score
            else:
                aspect_path = self.aspect_crf.viterbi_tags(aspect_logit, mask)
                polarity_path = self.polarity_crf.viterbi_tags(
                    polarity_logit, mask)
                return aspect_path, polarity_path
        else:
            return aspect_logit, polarity_logit
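
The two embedding tables above are loaded from pretrained matrices, frozen, and the token representation is the concatenation of its general and domain vectors. A small sketch of that pattern with invented sizes (nn.Embedding.from_pretrained freezes the weights by default):

import torch
import torch.nn as nn

general_weights = torch.randn(1000, 300)   # hypothetical general-domain vectors
domain_weights = torch.randn(1000, 100)    # hypothetical in-domain vectors

general_embedding = nn.Embedding.from_pretrained(general_weights, freeze=True)
domain_embedding = nn.Embedding.from_pretrained(domain_weights, freeze=True)

tokens = torch.randint(0, 1000, (2, 7))                    # (batch, seq_len)
features = torch.cat((general_embedding(tokens),
                      domain_embedding(tokens)), dim=2)    # (batch, seq_len, 400)
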
Example #8
0
class JointClassifier(Model):
    """
    Classifies NER tags and RE classes jointly. Label encoding is expected to be 'BIO'.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    ner_tag_embedder : ``Embedding``, required
        Used to embed decoded ner tags as input to the relation scorer.
    encoder : ``Seq2SeqEncoder``
        An encoder that will learn the major logic of the task.
    relation_scorer : ``RelationScorer``
        A subtask model, that performs scoring of relations between entities.
    ner_tag_namespace : ``str``
        The vocabulary namespace of ner tags.
    evaluated_ner_labels : ``List[str]``, optional (default=``None``)
        The list of ner tag types that are to be used for f1 score computation.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 relation_scorer: RelationScorer,
                 ner_tag_namespace: str = 'tags',
                 evaluated_ner_labels: List[str] = None,
                 re_loss_weight: float = 1.0,
                 ner_tag_embedder: TokenEmbedder = None,
                 use_aux_ner_labels: bool = False,
                 aux_coarse_namespace: str = 'coarse_tags',
                 aux_modifier_namespace: str = 'modifier_tags',
                 aux_loss_weight: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab=vocab, regularizer=regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        # NER subtask 2
        self._ner_label_encoding = 'BIO'
        self._ner_tag_namespace = ner_tag_namespace
        ner_input_dim = self.encoder.get_output_dim()
        num_ner_tags = self.vocab.get_vocab_size(ner_tag_namespace)
        self.tag_projection_layer = TimeDistributed(
            Linear(ner_input_dim, num_ner_tags))

        self._use_aux_ner_labels = use_aux_ner_labels
        if self._use_aux_ner_labels:
            self._coarse_tag_namespace = aux_coarse_namespace
            self._num_coarse_tags = self.vocab.get_vocab_size(
                self._coarse_tag_namespace)
            self._coarse_projection_layer = TimeDistributed(
                Linear(ner_input_dim, self._num_coarse_tags))
            self._modifier_tag_namespace = aux_modifier_namespace
            self._num_modifier_tags = self.vocab.get_vocab_size(
                self._modifier_tag_namespace)
            self._modifier_projection_layer = TimeDistributed(
                Linear(ner_input_dim, self._num_modifier_tags))
            self._coarse_acc = CategoricalAccuracy()
            self._modifier_acc = CategoricalAccuracy()
            self._aux_loss_weight = aux_loss_weight

        self.ner_accuracy = CategoricalAccuracy()
        if evaluated_ner_labels is None:
            ignored_classes = None
        else:
            assert self._ner_label_encoding == 'BIO', 'expected BIO encoding'
            all_ner_tags = self.vocab.get_token_to_index_vocabulary(
                ner_tag_namespace).keys()
            ner_tag_classes = set(
                [bio_tag[2:] for bio_tag in all_ner_tags if len(bio_tag) > 2])
            ignored_classes = list(
                set(ner_tag_classes).difference(evaluated_ner_labels))
        self.ner_f1 = SpanBasedF1Measure(
            vocabulary=vocab,
            tag_namespace=ner_tag_namespace,
            label_encoding=self._ner_label_encoding,
            ignore_classes=ignored_classes)

        # Use constrained crf decoding with the BIO labeling scheme
        ner_labels = self.vocab.get_index_to_token_vocabulary(
            ner_tag_namespace)
        constraints = allowed_transitions(self._ner_label_encoding, ner_labels)

        self.crf = ConditionalRandomField(num_ner_tags,
                                          constraints,
                                          include_start_end_transitions=True)

        # RE subtask 3
        self.ner_tag_embedder = ner_tag_embedder
        self.relation_scorer = relation_scorer
        self._re_loss_weight = re_loss_weight

        initializer(self)

    @overrides
    def forward(
            self,
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None,
            relation_root_idxs: torch.LongTensor = None,
            relations: torch.LongTensor = None,
            binary_coref: torch.FloatTensor = None,
            spacy_patterns: torch.FloatTensor = None,
            coarse_tags: torch.LongTensor = None,
            modifier_tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ,no-member
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : torch.LongTensor
            An integer tensor containing the gold ner tag label indexes.
        relation_root_idxs : torch.LongTensor, optional (default = None)
            An integer tensor containing the gold relation head indexes for training.
        relations : torch.LongTensor, optional (default = None)
            An integer tensor containing the gold relation label indexes for training.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Additional information such as the original words and the entity ids.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = embedded_text_input.size()
        mask = get_text_field_mask(tokens)

        encoder_input_tensors = [embedded_text_input]
        if binary_coref is not None:
            encoder_input_tensors.append(binary_coref.unsqueeze(2))
        if spacy_patterns is not None:
            encoder_input_tensors.append(spacy_patterns.permute(0, 2, 1))
        if len(encoder_input_tensors) > 1:
            encoder_input = torch.cat(encoder_input_tensors, dim=2)
        else:
            encoder_input = encoder_input_tensors[0]

        # Shape: batch x seq_len x emb_dim
        encoded_text = self.encoder(encoder_input, mask)

        ner_logits = self.tag_projection_layer(encoded_text)
        best_ner_paths = self.crf.viterbi_tags(ner_logits, mask)

        # Just get the tags and ignore the score.
        predicted_ner_tags = []
        predicted_ner_tags_tensor = torch.zeros_like(mask)
        for ner_path, _ in best_ner_paths:
            batch_idx = len(predicted_ner_tags)
            predicted_ner_tags.append(ner_path)
            for token_idx, ner_tag_idx in enumerate(ner_path):
                predicted_ner_tags_tensor[batch_idx, token_idx] = ner_tag_idx
        # predicted_ner_tags = [x for x, y in best_ner_paths]

        output_dict = {
            "ner_logits": ner_logits,
            "mask": mask,
            "tags": predicted_ner_tags
        }

        if self._use_aux_ner_labels:
            coarse_logits = self._coarse_projection_layer(encoded_text)
            modifier_logits = self._modifier_projection_layer(encoded_text)

        if self.ner_tag_embedder is not None:
            embedded_tags = self.ner_tag_embedder(predicted_ner_tags_tensor)
            encoded_sequence = torch.cat([encoded_text, embedded_tags], dim=2)
        else:
            encoded_sequence = torch.cat([
                encoded_text, ner_logits,
                predicted_ner_tags_tensor.unsqueeze(2).float()
            ],
                                         dim=2)

        re_output = self.relation_scorer(encoded_sequence, mask,
                                         relation_root_idxs, relations)

        # Add a prefix for relation extraction logits
        output_dict['re_logits'] = re_output['logits']
        output_dict['relation_scores'] = re_output['relation_scores']

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(ner_logits, tags, mask)

            # It's not clear why, but pylint seems to think `log_likelihood` is tuple
            # (in fact, it's a torch.Tensor), so we need a disable.
            output_dict["ner_loss"] = -log_likelihood  # pylint: disable=invalid-unary-operand-type

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = torch.zeros_like(ner_logits)
            for i, instance_tags in enumerate(predicted_ner_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            self.ner_accuracy(class_probabilities, tags, mask.float())
            self.ner_f1(class_probabilities, tags, mask.float())

            output_dict['loss'] = output_dict[
                'ner_loss'] + self._re_loss_weight * re_output['loss']

            if self._use_aux_ner_labels:
                assert coarse_tags is not None and modifier_tags is not None, 'Auxiliary losses require auxiliary input'
                self._coarse_acc(coarse_logits, coarse_tags, mask.float())
                self._modifier_acc(modifier_logits, modifier_tags,
                                   mask.float())
                coarse_loss = sequence_cross_entropy_with_logits(
                    coarse_logits, coarse_tags, mask)
                modifier_loss = sequence_cross_entropy_with_logits(
                    modifier_logits, modifier_tags, mask)
                output_dict['loss'] += self._aux_loss_weight * (coarse_loss +
                                                                modifier_loss)

        # Attach metadata
        if metadata is not None:
            for key in metadata[0]:
                output_dict[key] = [x[key] for x in metadata]

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        output_dict = self.relation_scorer.decode(output_dict)
        # for key in ['relations', 'heads', 'head_offsets']:
        #     if key in re_output_dict:
        #         output_dict[key] = re_output_dict[key]
        output_dict["tags"] = [[
            self.vocab.get_token_from_index(tag, self._ner_tag_namespace)
            for tag in instance_tags
        ] for instance_tags in output_dict["tags"]]
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        re_metrics = self.relation_scorer.get_metrics(reset=reset)
        joint_metrics = {
            'ner_acc': self.ner_accuracy.get_metric(reset=reset),
            'ner_f1':
            self.ner_f1.get_metric(reset=reset)['f1-measure-overall'],
            're_acc': re_metrics['re_acc'],
        }
        if 're_f1' in re_metrics:
            joint_metrics['re_f1'] = re_metrics['re_f1']
        if self._use_aux_ner_labels:
            joint_metrics['coarse_acc'] = self._coarse_acc.get_metric(
                reset=reset)
            joint_metrics['modifier_acc'] = self._modifier_acc.get_metric(
                reset=reset)
        return joint_metrics
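
The constructor above derives the classes that SpanBasedF1Measure should ignore by stripping the BIO prefixes from the tag vocabulary and removing the labels that are actually evaluated. A standalone sketch of that set arithmetic with an invented tag set:

all_ner_tags = ['O', 'B-Person', 'I-Person', 'B-Location', 'I-Location', 'B-Misc', 'I-Misc']
evaluated_ner_labels = ['Person', 'Location']

# strip the 'B-'/'I-' prefix to recover the bare entity classes
ner_tag_classes = {bio_tag[2:] for bio_tag in all_ner_tags if len(bio_tag) > 2}
ignored_classes = sorted(ner_tag_classes.difference(evaluated_ner_labels))
print(ignored_classes)  # ['Misc']
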
Example #9
0
class ScienceIETagger(nn.Module, ClassNursery):
    def __init__(
        self,
        rnn2seqencoder: Lstm2SeqEncoder,
        hid_dim: int,
        num_classes: int,
        device: torch.device = torch.device("cpu"),
        task_constraints: Optional[List[Tuple[int, int]]] = None,
        process_constraints: Optional[List[Tuple[int, int]]] = None,
        material_constraints: Optional[List[Tuple[int, int]]] = None,
        character_encoder: Optional[LSTM2VecEncoder] = None,
        include_start_end_transitions: Optional[bool] = False,
    ):
        """

        Parameters
        ----------
        rnn2seqencoder : Lstm2SeqEncoder
            rnn2seq encoder that encodes instances into a sequence of hidden states
        hid_dim : int
            Hidden dimension of the lstm2seq encoder
        num_classes : int
            The number of classes for every token
        device : torch.device
            The device on which the model should be run
        task_constraints : List[Tuple[int, int]]
            A set of constraints that indicates
            valid transitions between states
        process_constraints : List[Tuple[int, int]]
            A set of constraints that indicates
            valid transitions between states
        material_constraints : List[Tuple[int, int]]
            A set of constraints that indicates
            valid transitions between states
        character_encoder : LSTM2VecEncoder
            Encodes the characters into a vector
        include_start_end_transitions : bool
            whether to include start and end transitions
        """
        super(ScienceIETagger, self).__init__()
        self.rnn2seqencoder = rnn2seqencoder
        self.hid_dim = hid_dim
        self.num_classes = num_classes
        self.device = device
        self._task_constraints = task_constraints
        self._process_constraints = process_constraints
        self._material_constraints = material_constraints
        self.character_encoder = character_encoder
        self.include_start_end_transitions = include_start_end_transitions

        self.task_crf = CRF(
            num_tags=self.num_classes,
            constraints=task_constraints,
            include_start_end_transitions=include_start_end_transitions,
        )
        self.process_crf = CRF(
            num_tags=self.num_classes,
            constraints=process_constraints,
            include_start_end_transitions=include_start_end_transitions,
        )
        self.material_crf = CRF(
            num_tags=self.num_classes,
            constraints=material_constraints,
            include_start_end_transitions=include_start_end_transitions,
        )

        self.hidden2task = nn.Linear(self.hid_dim, self.num_classes)
        self.hidden2process = nn.Linear(self.hid_dim, self.num_classes)
        self.hidden2material = nn.Linear(self.hid_dim, self.num_classes)

    def forward(
        self,
        iter_dict: Dict[str, Any],
        is_training: bool,
        is_validation: bool,
        is_test: bool,
    ):
        encoding = self.rnn2seqencoder(iter_dict=iter_dict)

        # batch_size * time_steps * num_classes
        task_logits = self.hidden2task(encoding)
        process_logits = self.hidden2process(encoding)
        material_logits = self.hidden2material(encoding)

        batch_size, time_steps, _ = task_logits.size()
        # all-ones LongTensor mask; no extra conversion is needed
        mask = torch.ones(size=(batch_size, time_steps), dtype=torch.long)
        mask = mask.to(self.device)

        assert task_logits.size(1) == process_logits.size(
            1) == material_logits.size(1)
        assert task_logits.size(2) == process_logits.size(
            2) == material_logits.size(2)

        # List[List[int]] N * T
        predicted_task_tags = self.task_crf.viterbi_tags(logits=task_logits,
                                                         mask=mask)
        predicted_process_tags = self.process_crf.viterbi_tags(
            logits=process_logits, mask=mask)
        predicted_material_tags = self.material_crf.viterbi_tags(
            logits=material_logits, mask=mask)

        predicted_task_tags = [tag for tag, _ in predicted_task_tags]
        predicted_process_tags = [tag for tag, _ in predicted_process_tags]
        predicted_material_tags = [tag for tag, _ in predicted_material_tags]

        # add the appropriate numbers
        predicted_task_tags = torch.LongTensor(predicted_task_tags)
        predicted_process_tags = torch.LongTensor(predicted_process_tags) + 8
        predicted_material_tags = torch.LongTensor(
            predicted_material_tags) + 16

        assert (len(predicted_task_tags) == len(predicted_process_tags) ==
                len(predicted_material_tags))
        # arrange the labels in N * 3T
        predicted_tags = torch.cat(
            [
                predicted_task_tags, predicted_process_tags,
                predicted_material_tags
            ],
            dim=1,
        )
        predicted_tags = predicted_tags.tolist()

        output_dict = {
            "task_logits": task_logits,
            "process_logits": process_logits,
            "material_logits": material_logits,
            "predicted_task_tags": predicted_task_tags.tolist(),
            "predicted_process_tags": predicted_process_tags.tolist(),
            "predicted_material_tags": predicted_material_tags.tolist(),
            "predicted_tags": predicted_tags,
        }

        if is_training or is_validation:
            labels = iter_dict["label"]
            len_tokens = iter_dict["len_tokens"]
            mask = get_mask(batch_size=batch_size,
                            max_size=time_steps,
                            lengths=len_tokens)
            mask = mask.to(self.device)
            # work on a copy: mutating `labels` in place would also change iter_dict["label"]
            labels_copy = copy.deepcopy(labels)
            assert labels.ndimension() == 2, self.msg_printer(
                f"For Science IE Tagger, labels should have 2 dimensions"
                f"batch_size, 3 * max_length. The labels you passed have "
                f"{labels.ndimension()}")

            task_labels, process_labels, material_labels = torch.chunk(
                labels_copy, chunks=3, dim=1)
            process_labels -= 8
            material_labels -= 16
            task_loss = -self.task_crf(task_logits, task_labels, mask)
            process_loss = -self.process_crf(process_logits, process_labels,
                                             mask)
            material_loss = -self.material_crf(material_logits,
                                               material_labels, mask)
            loss = task_loss + process_loss + material_loss
            output_dict["loss"] = loss

        return output_dict
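
The tagger above packs the gold labels of the three tag sets into a single (batch, 3 * T) tensor, with the process and material blocks shifted by 8 and 16 so the ids do not collide; the loss path then splits the tensor and undoes the offsets. A short sketch of that bookkeeping with invented sizes (the offsets mirror the constants used above; the count of 8 tags per task is an assumption):

import torch

batch_size, T, tags_per_task = 2, 4, 8
task = torch.randint(0, tags_per_task, (batch_size, T))
process = torch.randint(0, tags_per_task, (batch_size, T)) + 8
material = torch.randint(0, tags_per_task, (batch_size, T)) + 16
labels = torch.cat([task, process, material], dim=1)          # (batch, 3 * T)

task_labels, process_labels, material_labels = torch.chunk(labels.clone(), chunks=3, dim=1)
process_labels -= 8       # back into the 0..7 range each CRF expects
material_labels -= 16
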
Example #10
0
class BertMiddleModel(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: str,
        dropout: float = 0.0,
        requires_grad: str = "none",
        use_crf: bool = False,
        pos_weight: float = 1.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ):

        super(BertMiddleModel, self).__init__(vocab, regularizer)
        self._vocabulary = vocab
        self._bert_model = BertModel.from_pretrained(bert_model)
        self._dropout = torch.nn.Dropout(p=dropout)
        self._classification_layer = torch.nn.Linear(
            self._bert_model.config.hidden_size, 2)

        self._use_crf = use_crf

        self._pos_weight = torch.Tensor([1 / (1 - pos_weight), 1 / pos_weight])
        self._pos_weight = torch.nn.Parameter(self._pos_weight /
                                              self._pos_weight.min())
        self._pos_weight.requires_grad = False

        if use_crf:
            self._crf = ConditionalRandomField(num_tags=2)

        self.embedding_layers = ["BertEmbedding"]

        if requires_grad in ["none", "all"]:
            for param in self._bert_model.parameters():
                param.requires_grad = requires_grad == "all"
        else:
            model_name_regexes = requires_grad.split(",")
            for name, param in self._bert_model.named_parameters():
                found = any([regex in name for regex in model_name_regexes])
                param.requires_grad = found

        for n, v in self._bert_model.named_parameters():
            if n.startswith("classifier"):
                v.requires_grad = True

        self._token_prf = F1Measure(1)

        initializer(self)

    def forward(self,
                document,
                query=None,
                rationale=None,
                metadata=None,
                label=None) -> Dict[str, Any]:
        input_ids = document["bert"]
        input_mask = (input_ids != 0).long()
        starting_offsets = document["bert-starting-offsets"]  # (B, T)

        last_hidden_states, _ = self._bert_model(
            input_ids,
            attention_mask=input_mask,
            position_ids=document["bert-position-ids"])

        token_embeddings, span_mask = generate_embeddings_for_pooling(
            last_hidden_states, starting_offsets,
            document["bert-ending-offsets"])

        token_embeddings = util.masked_max(token_embeddings,
                                           span_mask.unsqueeze(-1),
                                           dim=2)
        token_embeddings = token_embeddings * document["mask"].unsqueeze(-1)

        logits = self._classification_layer(self._dropout(token_embeddings))
        assert logits.shape[0:2] == starting_offsets.shape

        if self._use_crf:
            best_paths = self._crf.viterbi_tags(logits, mask=document["mask"])
            best_paths = [b[0] for b in best_paths]
            best_paths = [
                x + [0] * (logits.shape[1] - len(x)) for x in best_paths
            ]
            best_paths = torch.Tensor(best_paths).to(
                logits.device) * document["mask"]
        else:
            best_paths = (logits[:, :, 1] > 0.5).long() * document["mask"]

        output_dict = {}

        output_dict["predicted_rationales"] = best_paths
        output_dict["mask"] = document["mask"]
        output_dict["metadata"] = metadata

        if rationale is not None:
            if self._use_crf:
                output_dict["loss"] = -self._crf(logits, rationale,
                                                 document["mask"])
            else:
                output_dict["loss"] = ((F.cross_entropy(
                    logits.view(-1, logits.shape[-1]),
                    rationale.view(-1),
                    reduction="none",
                    weight=self._pos_weight,
                ) * document["mask"].view(-1)).sum(-1).mean())

            best_paths = best_paths.unsqueeze(-1)
            best_paths = torch.cat([1 - best_paths, best_paths], dim=-1)
            self._token_prf(best_paths, rationale, document["mask"])
        return output_dict

    def extract_rationale(self, output_dict):
        rationales = []
        sentences = [x["tokens"] for x in output_dict["metadata"]]
        predicted_rationales = output_dict["predicted_rationales"].cpu().data.numpy()
        for path, words in zip(predicted_rationales, sentences):
            path = list(path)[:len(words)]
            words = [x.text for x in words]
            starts, ends = [], []
            path.append(0)  # sentinel 0 so a span reaching the last token is closed
            prev = 0
            for i, cur in enumerate(path):
                if prev == 0 and cur == 1:
                    starts.append(i)
                if prev == 1 and cur == 0:
                    ends.append(i)
                prev = cur

            assert len(starts) == len(ends)
            spans = list(zip(starts, ends))

            rationales.append({
                "document":
                " ".join([w for i, w in zip(path, words) if i == 1]),
                "spans": [{
                    "span": (s, e),
                    "value": 1
                } for s, e in spans],
                "metadata":
                None,
            })

        return rationales

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = self._token_prf.get_metric(reset)
        return dict(zip(["p", "r", "f1"], metrics))

    def decode(self, output_dict):
        rationales = self.extract_rationale(output_dict)
        new_output_dict = {}

        new_output_dict['rationale'] = rationales
        new_output_dict['document'] = [r['document'] for r in rationales]

        if 'query' in output_dict['metadata'][0]:
            new_output_dict['query'] = [
                m['query'] for m in output_dict['metadata']
            ]

        for m in output_dict["metadata"]:
            if 'convert_tokens_to_instance' in m:
                del m["convert_tokens_to_instance"]

        new_output_dict['label'] = [
            m['label'] for m in output_dict['metadata']
        ]
        new_output_dict['metadata'] = output_dict['metadata']

        return new_output_dict
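
extract_rationale above turns a 0/1 tag path into half-open token spans by scanning for 0→1 and 1→0 transitions. A self-contained sketch of that scan as a helper function (the function name is made up):

def binary_path_to_spans(path):
    """Return half-open (start, end) spans of consecutive 1s in a 0/1 sequence."""
    starts, ends, prev = [], [], 0
    for i, cur in enumerate(list(path) + [0]):   # trailing 0 closes a span at the end
        if prev == 0 and cur == 1:
            starts.append(i)
        if prev == 1 and cur == 0:
            ends.append(i)
        prev = cur
    return list(zip(starts, ends))

assert binary_path_to_spans([1, 1, 0, 1]) == [(0, 2), (3, 4)]
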
Example #11
0
class DualCrossSharedLSTM(nn.Module):
    def __init__(self, general_embeddings, domain_embeddings, input_size, hidden_size, aspect_tag_classes,
                 polarity_tag_classes, k, dropout=0.5):
        super(DualCrossSharedLSTM, self).__init__()
        self.general_embedding = nn.Embedding(num_embeddings=general_embeddings.size(0),
                                              embedding_dim=general_embeddings.size(1),
                                              padding_idx=0).from_pretrained(general_embeddings)
        self.domain_embedding = nn.Embedding(num_embeddings=domain_embeddings.size(0),
                                             embedding_dim=domain_embeddings.size(1),
                                             padding_idx=0).from_pretrained(domain_embeddings)
        self.general_embedding.weight.requires_grad = False
        self.domain_embedding.weight.requires_grad = False

        self.dropout = dropout

        self.aspect_rnn1 = DynamicRNN(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)
        self.polarity_rnn1 = DynamicRNN(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)

        self.csu = Cross_Shared_Unit(k, 2 * hidden_size)

        self.aspect_rnn2 = DynamicRNN(hidden_size * 2, hidden_size, num_layers=1, batch_first=True, bidirectional=True)
        self.polarity_rnn2 = DynamicRNN(hidden_size * 2, hidden_size, num_layers=1, batch_first=True, bidirectional=True)

        self.aspect_hidden2tag = nn.Linear(2 * hidden_size, aspect_tag_classes)
        self.polarity_hidden2tag = nn.Linear(2 * hidden_size, polarity_tag_classes)

        self.aspect_crf = ConditionalRandomField(aspect_tag_classes)
        self.polarity_crf = ConditionalRandomField(polarity_tag_classes)

        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, features, aspect_tags, polarity_tags, mask, lengths, testing=False):
        general_features = self.general_embedding(features)
        domain_features = self.domain_embedding(features)
        features = torch.cat((general_features, domain_features), dim=2)
        max_len = torch.max(lengths)
        mask = mask[:,:max_len]
        features = features[:,:max_len,:]
        features = self.dropout_layer(features)
        aspect_hidden,(hn,cn) = self.aspect_rnn1(features,lengths)
        polarity_hidden,(hn,cn) = self.polarity_rnn1(features,lengths)

        #CSU
        aspect_hidden, polarity_hidden = self.csu(aspect_hidden, polarity_hidden, max_pooling=False)

        aspect_hidden_,(hn,cn) = self.aspect_rnn2(aspect_hidden,lengths)
        polarity_hidden_,(hn,cn) = self.polarity_rnn2(polarity_hidden,lengths)

        #res

        aspect_hidden = aspect_hidden_ + aspect_hidden
        polarity_hidden = polarity_hidden_ + polarity_hidden

        aspect_logit = self.aspect_hidden2tag(aspect_hidden)
        polarity_logit = self.polarity_hidden2tag(polarity_hidden)

        if not testing:
            aspect_score = -self.aspect_crf(aspect_logit, aspect_tags, mask)
            polarity_score = -self.polarity_crf(polarity_logit, polarity_tags, mask)
            return aspect_score + polarity_score
        else:
            aspect_path = self.aspect_crf.viterbi_tags(aspect_logit, mask)
            polarity_path = self.polarity_crf.viterbi_tags(polarity_logit, mask)
            return aspect_path, polarity_path
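
The forward pass above truncates the padded batch to the longest real sequence and relies on a length-derived mask. A minimal sketch of that masking with invented shapes:

import torch

lengths = torch.tensor([5, 3, 4])                      # real sequence lengths in the batch
max_len = int(torch.max(lengths))
mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)   # (batch, max_len), bool

features = torch.randn(3, 10, 8)                       # padded to a fixed length of 10
features = features[:, :max_len, :]                    # drop the columns past max_len
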
Example #12
0
class EmbeddingClfHead(T.BaseClfHead):
    def __init__(self,
                 config,
                 lm_model,
                 lm_config,
                 embed_type='w2v',
                 w2v_path=None,
                 iactvtn='relu',
                 oactvtn='sigmoid',
                 fchdim=0,
                 extfc=False,
                 sample_weights=False,
                 num_lbs=1,
                 mlt_trnsfmr=False,
                 lm_loss=False,
                 do_drop=True,
                 pdrop=0.2,
                 do_norm=True,
                 norm_type='batch',
                 do_lastdrop=True,
                 do_crf=False,
                 do_thrshld=False,
                 constraints=[],
                 initln=False,
                 initln_mean=0.,
                 initln_std=0.02,
                 task_params={},
                 **kwargs):
        from util import config as C
        super(EmbeddingClfHead, self).__init__(
            config,
            lm_model,
            lm_config,
            sample_weights=sample_weights,
            num_lbs=num_lbs,
            mlt_trnsfmr=config.task_type in ['entlmnt', 'sentsim']
            and task_params.setdefault('sentsim_func', None) is not None,
            task_params=task_params,
            **kwargs)
        self.dim_mulriple = 2 if self.task_type in ['entlmnt', 'sentsim'] and (
            self.task_params.setdefault('sentsim_func', None) is None
            or self.task_params['sentsim_func'] == 'concat') else 1
        self.embed_type = embed_type
        if embed_type.startswith('w2v'):
            from gensim.models import KeyedVectors
            from gensim.models.keyedvectors import Word2VecKeyedVectors
            self.w2v_model = w2v_path if type(
                w2v_path) is Word2VecKeyedVectors else (
                    KeyedVectors.load(w2v_path, mmap='r')
                    if w2v_path and os.path.isfile(w2v_path) else None)
            assert (self.w2v_model)
            self.n_embd = self.w2v_model.syn0.shape[1] + (
                self.n_embd if hasattr(self, 'n_embd') else 0)
            config.register_callback(
                'mdl_trsfm', EmbeddingClfHead.callback_update_w2v_model(self))
        elif embed_type.startswith('elmo'):
            self.vocab_size = 793471
            self.n_embd = lm_config['elmoedim'] * 2 + (
                self.n_embd if hasattr(self, 'n_embd') else 0
            )  # two ELMo layer * ELMo embedding dimensions
            config.register_callback(
                'mdl_trsfm',
                EmbeddingClfHead.callback_update_elmo_config(self))
        elif embed_type.startswith('elmo_w2v'):
            from gensim.models import KeyedVectors
            from gensim.models.keyedvectors import Word2VecKeyedVectors
            self.w2v_model = w2v_path if type(
                w2v_path) is Word2VecKeyedVectors else (
                    KeyedVectors.load(w2v_path, mmap='r')
                    if w2v_path and os.path.isfile(w2v_path) else None)
            assert (self.w2v_model)
            self.vocab_size = 793471
            self.n_embd = self.w2v_model.syn0.shape[
                1] + lm_config['elmoedim'] * 2 + (self.n_embd if hasattr(
                    self, 'n_embd') else 0)
            config.register_callback(
                'mdl_trsfm', EmbeddingClfHead.callback_update_w2v_model(self))
            config.register_callback(
                'mdl_trsfm',
                EmbeddingClfHead.callback_update_elmo_config(self))
        self.norm = C.NORM_TYPE_MAP[norm_type](
            self.maxlen
        ) if self.task_type == 'nmt' else C.NORM_TYPE_MAP[norm_type](
            self.n_embd)
        self._int_actvtn = C.ACTVTN_MAP[iactvtn]
        self._out_actvtn = C.ACTVTN_MAP[oactvtn]
        self.fchdim = fchdim
        self.extfc = extfc
        self.hdim = self.dim_mulriple * self.n_embd if self.mlt_trnsfmr and self.task_type in [
            'entlmnt', 'sentsim'
        ] else self.n_embd
        self.linear = self.__init_linear__()
        if (initln):
            self.linear.apply(H._weights_init(mean=initln_mean,
                                              std=initln_std))
        if self.do_extlin:
            self.extlinear = nn.Linear(self.n_embd, self.n_embd)
            if (initln):
                self.extlinear.apply(
                    H._weights_init(mean=initln_mean, std=initln_std))
        self.crf = ConditionalRandomField(num_lbs) if do_crf else None

    def __init_linear__(self):
        use_gpu = next(self.parameters()).is_cuda
        linear = (
            nn.Sequential(
                nn.Linear(self.hdim, self.fchdim), self._int_actvtn(),
                nn.Linear(self.fchdim, self.fchdim), self._int_actvtn(),
                *([] if self.task_params.setdefault('sentsim_func', None)
                  and self.task_params['sentsim_func'] != 'concat' else
                  [nn.Linear(self.fchdim, self.num_lbs),
                   self._out_actvtn()]))
            if self.task_type in ['entlmnt', 'sentsim'] else nn.Sequential(
                nn.Linear(self.hdim, self.fchdim), self._int_actvtn(),
                nn.Linear(self.fchdim, self.fchdim), self._int_actvtn(),
                nn.Linear(self.fchdim, self.num_lbs))
        ) if self.fchdim else (nn.Sequential(*(
            [nn.Linear(self.hdim, self.hdim
                       ), self._int_actvtn()]
            if self.task_params.setdefault('sentsim_func', None)
            and self.task_params['sentsim_func'] != 'concat' else
            [nn.Linear(self.hdim, self.num_lbs),
             self._out_actvtn()])) if self.task_type in ['entlmnt', 'sentsim']
                               else nn.Linear(self.hdim, self.num_lbs))
        return linear.to('cuda') if use_gpu else linear

    def __lm_head__(self):
        return EmbeddingHead(self)

    def _w2v(self, input_ids, use_gpu=False):
        wembd_tnsr = torch.tensor([self.w2v_model.syn0[s] for s in input_ids])
        if use_gpu: wembd_tnsr = wembd_tnsr.to('cuda')
        return wembd_tnsr

    def _sentvec(self, input_ids, use_gpu=False):
        pass

    def forward(self,
                input_ids,
                *extra_inputs,
                labels=None,
                past=None,
                weights=None,
                embedding_mode=False,
                ret_mask=False):
        use_gpu = next(self.parameters()).is_cuda
        if self.sample_weights and len(extra_inputs) > 0:
            sample_weights = extra_inputs[-1]
            extra_inputs = extra_inputs[:-1]
        else:
            sample_weights = None
        unsolved_input_keys, unsolved_inputs = self.embed_type.split(
            '_'), [input_ids] + list(extra_inputs)
        extra_inputs_dict = dict(
            zip([x for x in self.input_keys if x != 'input_ids'],
                extra_inputs))
        pool_idx = extra_inputs_dict['mask'].sum(1)
        mask = extra_inputs_dict['mask']  # mask of the original textual input
        clf_hs = []
        if self.task_type in ['entlmnt', 'sentsim']:
            if (self.embed_type.startswith('elmo')):
                embeddings = (self.lm_model(input_ids[0]),
                              self.lm_model(input_ids[1]))
                clf_hs.append((torch.cat(embeddings[0]['elmo_representations'],
                                         dim=-1),
                               torch.cat(embeddings[1]['elmo_representations'],
                                         dim=-1)))
                del unsolved_input_keys[0]
                del unsolved_inputs[0]
            for input_key, input_tnsr in zip(unsolved_input_keys,
                                             unsolved_inputs):
                clf_hs.append([
                    getattr(self, '_%s' % input_key)(input_tnsr[x],
                                                     use_gpu=use_gpu)
                    for x in [0, 1]
                ])
            clf_h = [torch.cat(embds, dim=-1) for embds in zip(*clf_hs)]
        else:
            if (self.embed_type.startswith('elmo')):
                embeddings = self.lm_model(input_ids)
                clf_hs.append(
                    torch.cat(embeddings['elmo_representations'], dim=-1))
                del unsolved_input_keys[0]
                del unsolved_inputs[0]
            for input_key, input_tnsr in zip(unsolved_input_keys,
                                             unsolved_inputs):
                clf_hs.append(
                    getattr(self, '_%s' % input_key)(input_tnsr,
                                                     use_gpu=use_gpu))
            clf_h = torch.cat(clf_hs, dim=-1)
        if labels is None:
            return (clf_h, mask) if ret_mask else (clf_h, )
        # Calculate language model loss
        if (self.lm_loss):
            lm_logits, lm_target = self.lm_logit(input_ids, clf_h,
                                                 extra_inputs_dict)
            lm_loss_func = nn.CrossEntropyLoss(ignore_index=-1,
                                               reduction='none')
            lm_loss = lm_loss_func(
                lm_logits.contiguous().view(-1, lm_logits.size(-1)),
                lm_target.contiguous().view(-1)).view(input_ids.size(0), -1)
            if sample_weights is not None: lm_loss *= sample_weights
        else:
            lm_loss = None
        return (clf_h, lm_loss, mask) if ret_mask else (clf_h, lm_loss)

    def _forward(self,
                 clf_h,
                 mask,
                 labels=None,
                 weights=None):  # For fine-tune task
        if self.task_type in ['entlmnt', 'sentsim']:
            if self.do_norm: clf_h = [self.norm(clf_h[x]) for x in [0, 1]]
            clf_h = [self.dropout(clf_h[x]) for x in [0, 1]]
            if (self.task_type == 'entlmnt' or
                    self.task_params.setdefault('sentsim_func', None) is None
                    or self.task_params['sentsim_func'] == 'concat'):
                if self.task_params.setdefault('concat_strategy',
                                               'normal') == 'diff':
                    clf_h = torch.cat(
                        clf_h +
                        [torch.abs(clf_h[0] - clf_h[1]), clf_h[0] * clf_h[1]],
                        dim=-1)
                elif self.task_params.setdefault('concat_strategy',
                                                 'normal') == 'flipflop':
                    clf_h = (torch.cat(clf_h, dim=-1) +
                             torch.cat(clf_h[::-1], dim=-1))
                else:
                    clf_h = torch.cat(clf_h, dim=-1)
                clf_logits = self.linear(clf_h) if self.linear else clf_h
            else:
                clf_logits = clf_h = F.pairwise_distance(
                    self.linear(clf_h[0]), self.linear(
                        clf_h[1]), 2, eps=1e-12) if self.task_params[
                            'sentsim_func'] == 'dist' else F.cosine_similarity(
                                self.linear(clf_h[0]),
                                self.linear(clf_h[1]),
                                dim=1,
                                eps=1e-12)
        else:
            if self.do_norm: clf_h = self.norm(clf_h)
            clf_h = self.dropout(clf_h)
            clf_logits = self.linear(clf_h)
            if self.do_lastdrop: clf_logits = self.last_dropout(clf_logits)

        if (labels is None):
            if self.crf:
                # `use_gpu` and the batch/sequence sizes were read from the
                # out-of-scope `input_ids` in the original; derive them from
                # the module parameters and the `mask` argument instead.
                use_gpu = next(self.parameters()).is_cuda
                batch_size, seq_len = mask.size()[0], mask.size()[1]
                tag_seq, score = zip(*self.crf.viterbi_tags(
                    clf_logits.view(batch_size, -1, self.num_lbs),
                    torch.ones(batch_size, seq_len).int()))
                tag_seq = torch.tensor(tag_seq).to(
                    'cuda') if use_gpu else torch.tensor(tag_seq)
                clf_logits = torch.zeros(
                    (*tag_seq.size(),
                     self.num_lbs)).to('cuda') if use_gpu else torch.zeros(
                         (*tag_seq.size(), self.num_lbs))
                clf_logits = clf_logits.scatter(-1, tag_seq.unsqueeze(-1), 1)
                return clf_logits
            if (self.task_type == 'sentsim'
                    and self.task_params.setdefault('sentsim_func', None)
                    and self.task_params['sentsim_func'] !=
                    self.task_params['ymode']):
                return 1 - clf_logits.view(-1, self.num_lbs)
            return clf_logits.view(-1, self.num_lbs)
        if self.crf:
            # the CRF log-likelihood needs the gold tags as well as the mask;
            # the original dropped `labels` and read the batch size from the
            # out-of-scope `input_ids`.
            clf_loss = -self.crf(
                clf_logits.view(mask.size()[0], -1, self.num_lbs),
                labels.view(mask.size()[0], -1), mask.long())
        elif self.task_type == 'mltc-clf' or self.task_type == 'entlmnt' or self.task_type == 'nmt':
            loss_func = nn.CrossEntropyLoss(weight=weights, reduction='none')
            clf_loss = loss_func(clf_logits.view(-1, self.num_lbs),
                                 labels.view(-1))
        elif self.task_type == 'mltl-clf':
            loss_func = nn.BCEWithLogitsLoss(
                pos_weight=10 * weights if weights is not None else None,
                reduction='none')
            clf_loss = loss_func(clf_logits.view(-1, self.num_lbs),
                                 labels.view(-1, self.num_lbs).float())
        elif self.task_type == 'sentsim':
            from util import config as C
            loss_cls = C.RGRSN_LOSS_MAP[self.task_params.setdefault(
                'loss', 'contrastive')]
            loss_func = loss_cls(
                reduction='none',
                x_mode=C.SIM_FUNC_MAP.setdefault(
                    self.task_params['sentsim_func'], 'dist'),
                y_mode=self.task_params.setdefault('ymode', 'sim')
            ) if self.task_params.setdefault(
                'sentsim_func', None
            ) and self.task_params['sentsim_func'] != 'concat' else nn.MSELoss(
                reduction='none')
            clf_loss = loss_func(clf_logits.view(-1), labels.view(-1))
        return clf_loss

    def _filter_vocab(self):
        pass

    @classmethod
    def callback_update_w2v_model(cls, model):
        def _callback(config):
            from util import config as C
            setattr(config, 'w2v_model', model.w2v_model)
            config.delayed_update(
                C.Configurable.PREDEFINED_MODEL_CONFIG_DELAYED_UPDATES[
                    config.model])

        return _callback

    @classmethod
    def callback_update_elmo_config(cls, model):
        def _callback(config):
            from util import config as C
            setattr(config, 'lm_config', model.lm_config)
            config.delayed_update(
                C.Configurable.PREDEFINED_MODEL_CONFIG_DELAYED_UPDATES[
                    config.model])

        return _callback
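
_forward above supports several ways of combining the two sentence encodings before the classifier ('concat', 'diff', 'flipflop'). A short sketch of the three combinations on made-up encodings:

import torch

u, v = torch.randn(4, 16), torch.randn(4, 16)          # hypothetical sentence encodings

concat = torch.cat([u, v], dim=-1)                                  # 'normal' concatenation
diff = torch.cat([u, v, torch.abs(u - v), u * v], dim=-1)           # 'diff': add |u - v| and u * v
flipflop = torch.cat([u, v], dim=-1) + torch.cat([v, u], dim=-1)    # 'flipflop': order-symmetric
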
Example #13
0
class KnowledgeEnhancedSlotTaggingModel(Model):
    
    def __init__(self, 
                 vocab: Vocabulary,
                 bert_embedder: Optional[PretrainedBertEmbedder] = None,
                 encoder: Optional[Seq2SeqEncoder] = None,
                 dropout: Optional[float] = None,
                 use_crf: bool = True) -> None:
        super().__init__(vocab)

        if bert_embedder:
            self.use_bert = True
            self.bert_embedder = bert_embedder
        else:
            self.use_bert = False
            self.basic_embedder = BasicTextFieldEmbedder({
                "tokens": Embedding(vocab.get_vocab_size(namespace="tokens"), 1024)
            })
            self.rnn = Seq2SeqEncoder.from_params(Params({     
                "type": "lstm",
                "input_size": 1024,
                "hidden_size": 512,
                "bidirectional": True,
                "batch_first": True
            }))

        self.encoder = encoder

        if encoder:
            hidden2tag_in_dim = encoder.get_output_dim()
        else:
            hidden2tag_in_dim = bert_embedder.get_output_dim()
        self.hidden2tag = TimeDistributed(torch.nn.Linear(
            in_features=hidden2tag_in_dim,
            out_features=vocab.get_vocab_size("labels")))
        
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None
        
        self.use_crf = use_crf
        if use_crf:
            crf_constraints = allowed_transitions(
                constraint_type="BIO",
                labels=vocab.get_index_to_token_vocabulary("labels")
            )
            self.crf = ConditionalRandomField(
                num_tags=vocab.get_vocab_size("labels"),
                constraints=crf_constraints,
                include_start_end_transitions=True
            )
        
        self.f1 = SpanBasedF1Measure(vocab, 
                                     tag_namespace="labels",
                                     ignore_classes=["news/type","negation",
                                                     "demonstrative_reference",
                                                     "timer/noun","timer/attributes"],
                                     label_encoding="BIO")

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                wordnet: Dict[str, torch.Tensor] = None,
                slot_labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Return a Dict (str -> torch.Tensor), which contains fields:
            mask - the mask matrix of ``sentence``, shape: (batch_size, seq_length)
            embeddings - the embedded tokens, shape: (batch_size, seq_length, embed_size)
            encoder_out - the output of contextual encoder, shape: (batch_size, seq_length, num_features)
            tag_logits - the output of tag projection layer, shape: (batch_size, seq_length, num_tags)
            predicted_tags - the output of CRF layer (use viterbi algorithm to obtain best paths),
                             shape: (batch_size, seq_length)
        """
        # print("bert(token piece ids) shape:", sentence["bert"].shape)
        # print("bert-offsets shape:", sentence["bert-offsets"].shape)
        # print("bert-type-ids shape:", sentence["bert-type-ids"].shape)
        # print("slot-labels shape:", slot_labels.shape)
        # bert_tokenizer = BertTokenizer.from_pretrained("/home1/yym2019/downloads/word-embeddings/bert-large-uncased/vocab.txt")
        # print("bert wordpieces:", bert_tokenizer.convert_ids_to_tokens([tensor.item() for tensor in sentence["bert"][1]]))
        # exit()

        output = {}

        mask = get_text_field_mask(sentence)
        output["mask"] = mask
        # print("mask shape:", mask.shape)
        
        if self.use_bert:
            embeddings = self.bert_embedder(sentence["bert"], sentence["bert-offsets"], sentence["bert-type-ids"])
            if self.dropout:
                embeddings = self.dropout(embeddings)
            output["embeddings"] = embeddings
            # print("embeddings shape:", embeddings.shape)
        else:
            embeddings = self.basic_embedder(sentence)
            if self.dropout:
                embeddings = self.dropout(embeddings)
            output["embeddings"] = embeddings
            embeddings = self.rnn(embeddings, mask)
            if self.dropout:
                embeddings = self.dropout(embeddings)
            output["rnn_out"] = embeddings
        
        if self.encoder:
            encoder_out = self.encoder(embeddings, mask)
            if self.dropout:
                encoder_out = self.dropout(encoder_out)
            output["encoder_out"] = encoder_out
            # print("encoder out shape:", encoder_out.shape)
        else:
            encoder_out = embeddings
        
        tag_logits = self.hidden2tag(encoder_out)
        output["tag_logits"] = tag_logits
        # print("tag logits shape:", tag_logits.shape)

        if self.use_crf:
            best_paths = self.crf.viterbi_tags(tag_logits, mask)
            predicted_tags = [x for x, y in best_paths]  # get the tags and ignore the score
            output["predicted_tags"] = predicted_tags
        else:
            output["predicted_tags"] = torch.argmax(tag_logits, dim=-1)  # pylint: disable=no-member
        
        if slot_labels is not None:
            if self.use_crf:
                log_likelihood = self.crf(tag_logits, slot_labels, mask)  # returns log-likelihood
                output["loss"] = -1.0 * log_likelihood  # add negative log-likelihood as loss
                
                # Represent viterbi tags as "class probabilities" that we can
                # feed into the metrics
                class_probabilities = tag_logits * 0.
                for i, instance_tags in enumerate(predicted_tags):
                    for j, tag_id in enumerate(instance_tags):
                        class_probabilities[i, j, tag_id] = 1
                self.f1(class_probabilities, slot_labels, mask.float())
            else:
                output["loss"] = sequence_cross_entropy_with_logits(tag_logits, slot_labels, mask)
                self.f1(tag_logits, slot_labels, mask.float())
        
        return output
    
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metric = self.f1.get_metric(reset)
        return {"precision": metric["precision-overall"],
                "recall": metric["recall-overall"],
                "f1": metric["f1-measure-overall"]}
class AttentiveCrfTagger(Model):
    """
    The ``AttentiveCrfTagger`` encodes a sequence of text with a ``Seq2SeqEncoder``,
    then uses a Conditional Random Field model to predict a tag for each token in the sequence.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the tokens ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``
        The encoder that we will use in between embedding tokens and predicting output tags.
    label_namespace : ``str``, optional (default=``labels``)
        This is needed to compute the SpanBasedF1Measure metric.
        Unless you did something unusual, the default value should be what you want.
    feedforward : ``FeedForward``, optional, (default = None).
        An optional feedforward layer to apply after the encoder.
    label_encoding : ``str``, optional (default=``None``)
        Label encoding to use when calculating span f1 and constraining
        the CRF at decoding time. Valid options are "BIO", "BIOUL", "IOB1", "BMES".
        Required if ``calculate_span_f1`` or ``constrain_crf_decoding`` is true.
    include_start_end_transitions : ``bool``, optional (default=``True``)
        Whether to include start and end transition parameters in the CRF.
    constrain_crf_decoding : ``bool``, optional (default=``None``)
        If ``True``, the CRF is constrained at decoding time to
        produce valid sequences of tags. If this is ``True``, then
        ``label_encoding`` is required. If ``None`` and
        label_encoding is specified, this is set to ``True``.
        If ``None`` and label_encoding is not specified, it defaults
        to ``False``.
    calculate_span_f1 : ``bool``, optional (default=``None``)
        Calculate span-level F1 metrics during training. If this is ``True``, then
        ``label_encoding`` is required. If ``None`` and
        label_encoding is specified, this is set to ``True``.
        If ``None`` and label_encoding is not specified, it defaults
        to ``False``.
    dropout:  ``float``, optional (default=``None``)
    verbose_metrics : ``bool``, optional (default = False)
        If true, metrics will be returned per label class in addition
        to the overall statistics.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 label_namespace: str = "labels",
                 feedforward: Optional[FeedForward] = None,
                 label_encoding: Optional[str] = None,
                 include_start_end_transitions: bool = True,
                 attention=None,
                 constrain_crf_decoding: bool = None,
                 calculate_span_f1: bool = None,
                 dropout: Optional[float] = None,
                 verbose_metrics: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self.label_namespace = label_namespace
        self.text_field_embedder = text_field_embedder
        self.num_tags = self.vocab.get_vocab_size(label_namespace)
        self.encoder = encoder
        self._verbose_metrics = verbose_metrics
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None
        self._feedforward = feedforward

        if feedforward is not None:
            output_dim = feedforward.get_output_dim()
        else:
            output_dim = self.encoder.get_output_dim()
        self.tag_projection_layer = TimeDistributed(Linear(output_dim,
                                                           self.num_tags))

        # if  constrain_crf_decoding and calculate_span_f1 are not
        # provided, (i.e., they're None), set them to True
        # if label_encoding is provided and False if it isn't.
        if constrain_crf_decoding is None:
            constrain_crf_decoding = label_encoding is not None
        if calculate_span_f1 is None:
            calculate_span_f1 = label_encoding is not None

        self.label_encoding = label_encoding
        if constrain_crf_decoding:
            if not label_encoding:
                raise ConfigurationError("constrain_crf_decoding is True, but "
                                         "no label_encoding was specified.")
            labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
            constraints = allowed_transitions(label_encoding, labels)
        else:
            constraints = None

        self.include_start_end_transitions = include_start_end_transitions
        self.crf = ConditionalRandomField(
                self.num_tags, constraints,
                include_start_end_transitions=include_start_end_transitions
        )

        self.metrics = {
                "accuracy": CategoricalAccuracy(),
                "accuracy3": CategoricalAccuracy(top_k=3)
        }
        self.calculate_span_f1 = calculate_span_f1
        if calculate_span_f1:
            if not label_encoding:
                raise ConfigurationError("calculate_span_f1 is True, but "
                                         "no label_encoding was specified.")
            self._f1_metric = SpanBasedF1Measure(vocab,
                                                 tag_namespace=label_namespace,
                                                 label_encoding=label_encoding)

        check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                               "text field embedding dim", "encoder input dim")
        if feedforward is not None:
            check_dimensions_match(encoder.get_output_dim(), feedforward.get_input_dim(),
                                   "encoder output dim", "feedforward input dim")


        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                # pylint: disable=unused-argument
                **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:

        logits : ``torch.FloatTensor``
            The logits that are the output of the ``tag_projection_layer``
        mask : ``torch.LongTensor``
            The text field mask for the input tokens
        tags : ``List[List[int]]``
            The predicted tags using the Viterbi algorithm.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """

        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        if self._feedforward is not None:
            encoded_text = self._feedforward(encoded_text)


        logits = self.tag_projection_layer(encoded_text)
        best_paths = self.crf.viterbi_tags(logits, mask)

        # Just get the tags and ignore the score.
        predicted_tags = [x for x, y in best_paths]

        output = {"logits": logits, "mask": mask, "tags": predicted_tags}

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(logits, tags, mask)
            output["loss"] = -log_likelihood

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = logits * 0.
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            for metric in self.metrics.values():
                metric(class_probabilities, tags, mask.float())
            if self.calculate_span_f1:
                self._f1_metric(class_probabilities, tags, mask.float())
        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids to the actual tags.
        ``output_dict["tags"]`` is a list of lists of tag_ids,
        so we use an ugly nested list comprehension.
        """
        output_dict["tags"] = [
                [self.vocab.get_token_from_index(tag, namespace=self.label_namespace)
                 for tag in instance_tags]
                for instance_tags in output_dict["tags"]
        ]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics_to_return = {metric_name: metric.get_metric(reset) for
                             metric_name, metric in self.metrics.items()}

        if self.calculate_span_f1:
            f1_dict = self._f1_metric.get_metric(reset=reset)
            if self._verbose_metrics:
                metrics_to_return.update(f1_dict)
            else:
                metrics_to_return.update({
                        x: y for x, y in f1_dict.items() if
                        "overall" in x})
        return metrics_to_return
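
The nested loops in ``forward`` above turn the Viterbi paths into one-hot "class probabilities" for the metrics. A hedged alternative using ``scatter_`` is sketched below; the helper name is hypothetical, and padded positions stay all-zero exactly as in the loop version:

from typing import List

import torch


def tags_to_class_probabilities(logits: torch.Tensor,
                                predicted_tags: List[List[int]]) -> torch.Tensor:
    """Build a (batch_size, seq_len, num_tags) one-hot tensor from Viterbi paths."""
    class_probabilities = torch.zeros_like(logits)
    for i, instance_tags in enumerate(predicted_tags):
        if not instance_tags:
            continue
        tag_ids = torch.tensor(instance_tags, device=logits.device)
        # Write 1.0 at (position j, tag_id) for the unpadded prefix of instance i.
        class_probabilities[i, :len(instance_tags)].scatter_(1, tag_ids.unsqueeze(1), 1.0)
    return class_probabilities
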
Beispiel #15
0
class SlotTaggingModel(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 bert_embedder: Optional[PretrainedBertEmbedder] = None,
                 encoder: Optional[Seq2SeqEncoder] = None,
                 dropout: Optional[float] = None,
                 use_crf: bool = True,
                 add_random_noise: bool = False,
                 add_attack_noise: bool = False,
                 do_noise_normalization: bool = True,
                 noise_norm: Optional[float] = None,
                 noise_loss_prob: Optional[float] = None,
                 add_noise_for: str = "ov",
                 rnn_after_embeddings: bool = False,
                 open_vocabulary_slots: Optional[List[str]] = None,
                 metrics_for_each_slot_type: bool = False) -> None:
        """
        Params
        ------
        vocab: the allennlp Vocabulary object, will be automatically passed
        bert_embedder: the pretrained BERT embedder. If it is not None, the pretrained BERT
                embedding (with its parameters frozen) will be used as the embedding layer. Otherwise, a look-up
                embedding matrix will be initialized with embedding size 1024. The default is None.
        encoder: the contextual encoder used after the embedding layer. If set to None, no contextual
                encoder will be used.
        dropout: the dropout rate; not set in any of our experiments.
        use_crf: if set to True, a CRF will be used at the end of the model (as the output layer). Otherwise,
                a softmax layer (with cross-entropy loss) will be used.
        add_random_noise: whether to add random noise to slots. Cannot be set simultaneously
                with add_attack_noise. This setting is used as a baseline in our experiments.
        add_attack_noise: whether to add adversarial attack noise to slots. Cannot be set simultaneously
                with add_random_noise.
        do_noise_normalization: if set to True, the normalization will be applied to gradients w.r.t. 
                token embeddings. Otherwise, the gradients won't be normalized.
        noise_norm: the normalization norm (L2) applied to gradients.
        noise_loss_prob: the alpha hyperparameter balancing the loss from the normal forward pass and the
                adversarial forward pass. See the paper for more details. Should be between 0 and 1.
        add_noise_for: if set to ov, the noise will only be applied to open-vocabulary slots. Otherwise,
                the noise will be applied to all slots (both open-vocabulary and normal slots).
        rnn_after_embeddings: if set to True, an additional BiLSTM layer will be applied after the embedding
                layer. Default is False.
        open_vocabulary_slots: the list of open-vocabulary slots. If not set, will be set to open-vocabulary
                slots of Snips dataset by default.
        metrics_for_each_slot_type: whether to log metrics for each slot type. Default is False.
        """
        super().__init__(vocab)

        if bert_embedder:
            self.use_bert = True
            self.bert_embedder = bert_embedder
        else:
            self.use_bert = False
            self.basic_embedder = BasicTextFieldEmbedder({
                "tokens":
                Embedding(vocab.get_vocab_size(namespace="tokens"), 1024)
            })
            self.rnn_after_embeddings = rnn_after_embeddings
            if rnn_after_embeddings:
                self.rnn = Seq2SeqEncoder.from_params(
                    Params({
                        "type": "lstm",
                        "input_size": 1024,
                        "hidden_size": 512,
                        "bidirectional": True,
                        "batch_first": True
                    }))

        self.encoder = encoder

        if encoder:
            hidden2tag_in_dim = encoder.get_output_dim()
        else:
            hidden2tag_in_dim = bert_embedder.get_output_dim()
        self.hidden2tag = TimeDistributed(
            torch.nn.Linear(in_features=hidden2tag_in_dim,
                            out_features=vocab.get_vocab_size("labels")))

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        else:
            self.dropout = None

        self.use_crf = use_crf
        if use_crf:
            crf_constraints = allowed_transitions(
                constraint_type="BIO",
                labels=vocab.get_index_to_token_vocabulary("labels"))
            self.crf = ConditionalRandomField(
                num_tags=vocab.get_vocab_size("labels"),
                constraints=crf_constraints,
                include_start_end_transitions=True)

        # default open_vocabulary slots: for SNIPS dataset
        open_vocabulary_slots = open_vocabulary_slots or [
            "playlist", "entity_name", "poi", "restaurant_name",
            "geographic_poi", "album", "track", "object_name", "movie_name"
        ]
        self.f1 = OVSpecSpanBasedF1Measure(
            vocab,
            tag_namespace="labels",
            ignore_classes=[],
            label_encoding="BIO",
            open_vocabulary_slots=open_vocabulary_slots)

        self.add_random_noise = add_random_noise
        self.add_attack_noise = add_attack_noise
        assert not (add_random_noise and
                    add_attack_noise), "both random and attack noise applied"
        if add_random_noise or add_attack_noise:
            self.do_noise_normalization = do_noise_normalization
            assert noise_norm is not None
            assert noise_loss_prob is not None and 0. <= noise_loss_prob <= 1.
            self.noise_norm = noise_norm
            self.noise_loss_prob = noise_loss_prob
            assert add_noise_for in ["ov", "all"]
            self.ov_noise_only = (add_noise_for == "ov")

        self.metrics_for_each_slot_type = metrics_for_each_slot_type

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                slot_labels: torch.Tensor = None,
                ov_slot_mask: torch.Tensor = None,
                slot_mask: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Params
        ------
        sentence: a Dict contains tensors of token ids (in "tokens" key) or (If use BERT as embedding
                layer) BERT BPE ids, offsets, segment ids. This parameter is the output of
                TextField.as_tensors(), see ~allennlp.data.fields.text_field.TextField for details.
                Each field should have shape (batch_size, seq_length)
        slot_labels: slot label ids (in BIO format), of shape (batch_size, seq_length)
        ov_slot_mask: binary mask, 1 for tokens of open-vocabulary slots, 0 for otherwise (non-slot tokens
                 and tokens of normal slots). Of shape (batch_size, seq_length)
        slot_mask: binary mask, 1 for tokens of slots (all slots), 0 for non-slot tokens (i.e. the O tag).
                Of shape (batch_size, seq_length)
        
        Return a Dict (str -> torch.Tensor), which contains fields:
                mask - the mask matrix of ``sentence``, shape: (batch_size, seq_length)
                embeddings - the embedded tokens, shape: (batch_size, seq_length, embed_size)
                encoder_out - the output of contextual encoder, shape: (batch_size, seq_length, num_features)
                tag_logits - the output of tag projection layer, shape: (batch_size, seq_length, num_tags)
                predicted_tags - the output of CRF layer (use viterbi algorithm to obtain best paths),
                             shape: (batch_size, seq_length)
        """
        output = {}

        mask = get_text_field_mask(sentence)
        output["mask"] = mask

        if self.use_bert:
            embeddings = self.bert_embedder(sentence["bert"],
                                            sentence["bert-offsets"],
                                            sentence["bert-type-ids"])
            if self.dropout:
                embeddings = self.dropout(embeddings)
            output["embeddings"] = embeddings
        else:
            embeddings = self.basic_embedder(sentence)
            if self.dropout:
                embeddings = self.dropout(embeddings)
            output["embeddings"] = embeddings
            if self.rnn_after_embeddings:
                embeddings = self.rnn(embeddings, mask)
                if self.dropout:
                    embeddings = self.dropout(embeddings)
                output["rnn_out"] = embeddings

        if not self.training:  # when predict or evaluate, no need for adding noise
            output.update(self._inner_forward(embeddings, mask, slot_labels))
        elif not self.add_random_noise and not self.add_attack_noise:  # for baseline
            output.update(self._inner_forward(embeddings, mask, slot_labels))
        else:  # add random noise or attack noise for open-vocabulary slots
            if self.add_random_noise:  # add random noise
                unnormalized_noise = torch.randn(
                    embeddings.shape).to(device=embeddings.device)
            else:  # add attack noise
                normal_loss = self._inner_forward(embeddings, mask,
                                                  slot_labels)["loss"]
                embeddings.retain_grad()  # we need the gradient w.r.t. the embeddings
                normal_loss.backward(retain_graph=True)
                unnormalized_noise = embeddings.grad.detach_()
                for p in self.parameters():
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()
            if self.do_noise_normalization:  # do normalization
                norm = unnormalized_noise.norm(p=2, dim=-1)
                normalized_noise = unnormalized_noise / (
                    norm.unsqueeze(dim=-1) + 1e-10)  # add 1e-10 to avoid NaN
            else:  # no normalization
                normalized_noise = unnormalized_noise
            if self.ov_noise_only:  # add noise to open-vocabulary slots only
                ov_slot_noise = self.noise_norm * normalized_noise * ov_slot_mask.unsqueeze(
                    dim=-1).float()
            else:  # add noise to all slots
                ov_slot_noise = self.noise_norm * normalized_noise * slot_mask.unsqueeze(
                    dim=-1).float()
            output["ov_slot_noise"] = ov_slot_noise
            noise_embeddings = embeddings + ov_slot_noise  # semantics decoupling using noise
            normal_sample_loss = self._inner_forward(
                embeddings, mask, slot_labels)["loss"]  # normal forward
            noise_sample_loss = self._inner_forward(
                noise_embeddings, mask,
                slot_labels)["loss"]  # adversarial forward
            loss = normal_sample_loss * (
                1 - self.noise_loss_prob
            ) + noise_sample_loss * self.noise_loss_prob
            output["loss"] = loss
        return output

    def _inner_forward(self, embeddings: torch.Tensor, mask: torch.Tensor,
                       slot_labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward from **embedding space** to a loss or predicted tags.
        """
        output = {}

        if self.encoder:
            encoder_out = self.encoder(embeddings, mask)
            if self.dropout:
                encoder_out = self.dropout(encoder_out)
            output["encoder_out"] = encoder_out
        else:
            encoder_out = embeddings

        tag_logits = self.hidden2tag(encoder_out)
        output["tag_logits"] = tag_logits

        if self.use_crf:
            best_paths = self.crf.viterbi_tags(tag_logits, mask)
            predicted_tags = [x for x, y in best_paths]  # get the tags and ignore the score
            predicted_score = [y for _, y in best_paths]
            output["predicted_tags"] = predicted_tags
            output["predicted_score"] = predicted_score
        else:
            output["predicted_tags"] = torch.argmax(tag_logits, dim=-1)  # pylint: disable=no-member

        if slot_labels is not None:
            if self.use_crf:
                log_likelihood = self.crf(tag_logits, slot_labels,
                                          mask)  # returns log-likelihood
                output["loss"] = -1.0 * log_likelihood  # add negative log-likelihood as loss

                # Represent viterbi tags as "class probabilities" that we can
                # feed into the metrics
                class_probabilities = tag_logits * 0.
                for i, instance_tags in enumerate(predicted_tags):
                    for j, tag_id in enumerate(instance_tags):
                        class_probabilities[i, j, tag_id] = 1
                self.f1(class_probabilities, slot_labels, mask.float())
            else:
                output["loss"] = sequence_cross_entropy_with_logits(
                    tag_logits, slot_labels, mask)
                self.f1(tag_logits, slot_labels, mask.float())

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metric = self.f1.get_metric(reset)

        results = {}

        if self.metrics_for_each_slot_type:
            results.update(metric)
        else:
            results.update({
                "precision": metric["precision-overall"],
                "precision-ov": metric["precision-ov"],
                "recall": metric["recall-overall"],
                "recall-ov": metric["recall-ov"],
                "f1": metric["f1-measure-overall"],
                "f1-ov": metric["f1-measure-ov"]
            })

        return results
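
The docstring and ``forward`` above describe the adversarial-noise step: take the gradient of the normal loss with respect to the token embeddings, L2-normalize it, scale it by ``noise_norm``, restrict it to (open-vocabulary) slot tokens, and run the tagger again on the perturbed embeddings. A stripped-down sketch of that step on dummy tensors; the quadratic loss stands in for the CRF loss and ``noise_norm`` is set to 1.0 purely for illustration:

import torch

batch_size, seq_len, embed_size = 2, 6, 8
embeddings = torch.randn(batch_size, seq_len, embed_size, requires_grad=True)
ov_slot_mask = torch.randint(0, 2, (batch_size, seq_len))  # 1 for open-vocabulary slot tokens

normal_loss = embeddings.pow(2).sum()  # stand-in for the CRF negative log-likelihood
normal_loss.backward()                 # populates embeddings.grad

unnormalized_noise = embeddings.grad.detach()
norm = unnormalized_noise.norm(p=2, dim=-1, keepdim=True)
normalized_noise = unnormalized_noise / (norm + 1e-10)  # add 1e-10 to avoid division by zero

noise_norm = 1.0
ov_slot_noise = noise_norm * normalized_noise * ov_slot_mask.unsqueeze(-1).float()
noise_embeddings = embeddings + ov_slot_noise  # the model feeds this through _inner_forward a second time
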
Beispiel #16
0
class CWSModel(Model):
    def __init__(self, model_path, vocab: Vocabulary):
        super().__init__(vocab)
        self.pretrained_tokenizer = BertForPreTraining.from_pretrained(
            model_path)
        config = BertConfig.from_pretrained(model_path)
        bert_model = BertForPreTraining(config)
        self.bert = bert_model.bert
        tags = vocab.get_index_to_token_vocabulary("tags")
        num_tags = len(tags)
        constraints = allowed_transitions(constraint_type="BMES", labels=tags)
        self.projection = torch.nn.Linear(768, num_tags)
        self.crf = ConditionalRandomField(num_tags=num_tags,
                                          constraints=constraints,
                                          include_start_end_transitions=False)

    def forward(self,
                tokens,
                attention_mask,
                token_type_ids,
                length,
                tags=None,
                metadata=None) -> Dict[str, torch.Tensor]:
        """

        :param tokens:
        :param attention_mask:
        :param token_type_ids:
        :param length: TODO (batch, 1) or (batch, )? Either way is fine; just append a view(-1) or view(-1, 1) at the end.
        :param tags:
        :param metadata:
        :return:
        """
        output_dict = dict()
        input_ids = tokens['tokens']['tokens']
        bert_outputs = self.bert(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids)
        bert_outputs = bert_outputs[0]  # (batch, sequence, hidden_size)

        # bert_outputs includes the two special symbols [CLS] and [SEP], and the tags and attention_mask
        # were built to line up with tokens, so these two symbols have to be stripped from
        # bert_outputs, tags and attention_mask.
        # In AllenNLP's CRF the tag at the first position (tag[0]) is always consumed, so the tag
        # corresponding to [CLS] must not be passed in; tags at later positions can be masked out via the mask.
        # At prediction time, however, the last position still has to be removed manually (based on length).
        bert_outputs = bert_outputs[:, 1:, :]
        logits = self.projection(bert_outputs)
        log_likelihood = torch.nn.functional.log_softmax(logits, -1)  # log-softmax emission scores fed to the CRF
        attention_mask = attention_mask[:, 1:]
        if tags is not None:
            tags = tags[:, 1:]

            loss = -self.crf(log_likelihood, tags, attention_mask)
            output_dict['loss'] = loss

        # Run Viterbi decoding
        best_path = self.crf.viterbi_tags(logits, attention_mask)
        output_dict['best_path'] = best_path

        output_dict['metadata'] = metadata

        output_dict['input_ids'] = input_ids[:, 1:]  # the [CLS] position has already been sliced off
        output_dict['attention_mask'] = attention_mask
        if tags is not None:
            output_dict['tags'] = tags
        best_path = [
            path[0][:mask.sum()]
            for path, mask in zip(best_path, attention_mask)
        ]
        output_dict['best_path'] = best_path
        return output_dict

    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        text_predict_tags = [[
            self.vocab.get_token_from_index(idx, 'tags') for idx in path
        ] for path in output_dict['best_path']]
        output_dict.update({'text_predict_tags': text_predict_tags})
        return output_dict
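
The comments in ``forward`` above explain why the [CLS] position is sliced off before the CRF and why each decoded path is trimmed with the attention mask afterwards. A small self-contained illustration with dummy values:

import torch

# Dummy BERT-style batch: position 0 is [CLS], trailing positions are padding.
logits_with_cls = torch.randn(2, 6, 4)                  # (batch, seq_len, num_tags)
attention_mask_with_cls = torch.tensor([[1, 1, 1, 1, 1, 0],
                                        [1, 1, 1, 1, 0, 0]])

# Drop [CLS] so the first tag fed to the CRF belongs to a real token.
logits = logits_with_cls[:, 1:, :]
attention_mask = attention_mask_with_cls[:, 1:]

# Dummy (tags, score) pairs standing in for crf.viterbi_tags(logits, attention_mask).
best_path = [([3, 1, 2, 0, 0], -1.2), ([0, 2, 1, 0, 0], -0.7)]
best_path = [tags[:int(mask.sum())]
             for (tags, score), mask in zip(best_path, attention_mask)]
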
Beispiel #17
0
class RNNTagger(nn.Module):
    def __init__(self,
                 n_vocab,
                 unigram_embed_size,
                 rnn_unit_type,
                 rnn_bidirection,
                 rnn_batch_first,
                 rnn_n_layers,
                 rnn_hidden_size,
                 mlp_n_layers,
                 mlp_hidden_size,
                 n_labels,
                 use_crf=True,
                 crf_top_k=1,
                 embed_dropout=0.0,
                 rnn_dropout=0.0,
                 mlp_dropout=0.0,
                 pretrained_unigram_embed_size=0,
                 pretrained_embed_usage=ModelUsage.NONE):
        super(RNNTagger, self).__init__()
        self.n_vocab = n_vocab
        self.unigram_embed_size = unigram_embed_size

        self.rnn_unit_type = rnn_unit_type
        self.rnn_bidirection = rnn_bidirection
        self.rnn_batch_first = rnn_batch_first
        self.rnn_n_layers = rnn_n_layers
        self.rnn_hidden_size = rnn_hidden_size

        self.mlp_n_layers = mlp_n_layers
        self.mlp_hidden_size = mlp_hidden_size
        self.n_labels = n_labels
        self.use_crf = use_crf
        self.crf_top_k = crf_top_k

        self.embed_dropout = embed_dropout
        self.rnn_dropout = rnn_dropout
        self.mlp_dropout = mlp_dropout

        self.pretrained_unigram_embed_size = pretrained_unigram_embed_size
        self.pretrained_embed_usage = pretrained_embed_usage

        self.unigram_embed = None
        self.pretrained_unigram_embed = None
        self.rnn = None
        self.mlp = None
        self.crf = None
        self.cross_entropy_loss = None

        print('### Parameters', file=sys.stderr)

        # embeddings layer(s)

        print('# Embedding dropout ratio={}'.format(self.embed_dropout),
              file=sys.stderr)
        self.unigram_embed, self.pretrained_unigram_embed = models.util.construct_embeddings(
            n_vocab, unigram_embed_size, pretrained_unigram_embed_size,
            pretrained_embed_usage)
        if self.pretrained_embed_usage != ModelUsage.NONE:
            print('# Pretrained embedding usage: {}'.format(
                self.pretrained_embed_usage),
                  file=sys.stderr)
        print('# Unigram embedding matrix: W={}'.format(
            self.unigram_embed.weight.shape),
              file=sys.stderr)
        embed_size = self.unigram_embed.weight.shape[1]
        if self.pretrained_unigram_embed is not None:
            if self.pretrained_embed_usage == ModelUsage.CONCAT:
                embed_size += self.pretrained_unigram_embed_size
                print('# Pretrained unigram embedding matrix: W={}'.format(
                    self.pretrained_unigram_embed.weight.shape),
                      file=sys.stderr)

        # recurrent layers

        self.rnn_unit_type = rnn_unit_type
        self.rnn = models.util.construct_RNN(unit_type=rnn_unit_type,
                                             embed_size=embed_size,
                                             hidden_size=rnn_hidden_size,
                                             n_layers=rnn_n_layers,
                                             batch_first=rnn_batch_first,
                                             dropout=rnn_dropout,
                                             bidirectional=rnn_bidirection)
        rnn_output_size = rnn_hidden_size * (2 if rnn_bidirection else 1)

        # MLP

        print('# MLP', file=sys.stderr)
        mlp_in = rnn_output_size
        self.mlp = MLP(input_size=mlp_in,
                       hidden_size=mlp_hidden_size,
                       n_layers=mlp_n_layers,
                       output_size=n_labels,
                       dropout=mlp_dropout,
                       activation=nn.Identity)

        # Inference layer (CRF/softmax)

        if self.use_crf:
            self.crf = ConditionalRandomField(n_labels)
            print('# CRF cost: {}'.format(self.crf.transitions.shape),
                  file=sys.stderr)
        else:
            self.softmax_cross_entropy = nn.CrossEntropyLoss()

    """
    us: batch of unigram sequences
    ls: batch of label sequences
    """

    # unigram and label
    def forward(self, us, ls=None, calculate_loss=True, decode=False):
        lengths = self.extract_lengths(us)
        us, ls = self.pad_features(us, ls)
        xs = self.extract_features(us)
        rs = self.rnn_output(xs, lengths)
        ys = self.mlp(rs)
        loss, ps = self.predict(ys,
                                ls=ls,
                                lengths=lengths,
                                calculate_loss=calculate_loss,
                                decode=decode)
        return loss, ps

    def extract_lengths(self, ts):
        device = ts[0].device
        return torch.tensor([t.shape[0] for t in ts], device=device)

    def pad_features(self, us, ls):
        batch_first = self.rnn_batch_first
        us = pad_sequence(us, batch_first=batch_first)
        ls = pad_sequence(ls, batch_first=batch_first) if ls else None

        return us, ls

    def extract_features(self, us):
        xs = []

        for u in us:
            ue = self.unigram_embed(u)
            if self.pretrained_unigram_embed is not None:
                if self.pretrained_embed_usage == ModelUsage.ADD:
                    pe = self.pretrained_unigram_embed(u)
                    ue = ue + pe
                elif self.pretrained_embed_usage == ModelUsage.CONCAT:
                    pe = self.pretrained_unigram_embed(u)
                    ue = torch.cat((ue, pe), 1)
            ue = F.dropout(ue, p=self.embed_dropout)
            xe = ue
            xs.append(xe)

        if self.rnn_batch_first:
            xs = torch.stack(xs, dim=0)
        else:
            xs = torch.stack(xs, dim=1)

        return xs

    def rnn_output(self, xs, lengths=None):
        if self.rnn_unit_type == 'lstm':
            hs, (hy, cy) = self.rnn(xs, lengths)
        else:
            hs, hy = self.rnn(xs)
        return hs

    def predict(self,
                rs,
                ls=None,
                lengths=None,
                calculate_loss=True,
                decode=False):
        if self.crf:
            return self.predict_crf(rs, ls, lengths, calculate_loss, decode)
        else:
            return self.predict_softmax(rs, ls, calculate_loss)

    def predict_softmax(self, ys, ls=None, calculate_loss=True):
        ps = []
        loss = torch.tensor(0, dtype=torch.float, device=ys.device)
        if ls is None:
            ls = [None] * len(ys)
        for y, l in zip(ys, ls):
            if calculate_loss:
                loss += self.softmax_cross_entropy(y, l)
            ps.append([torch.argmax(yi.data) for yi in y])

        return loss, ps

    def predict_crf(self,
                    hs,
                    ls=None,
                    lengths=None,
                    calculate_loss=True,
                    decode=False):
        device = hs.device
        if lengths is None:
            lengths = torch.tensor([h.shape[0] for h in hs], device=device)
        mask = get_mask_from_sequence_lengths(lengths, max_length=max(lengths))
        if not decode or self.crf_top_k == 1:
            ps = self.crf.viterbi_tags(hs, mask)
            ps, score = zip(*ps)
        else:
            ps = []
            psks = self.crf.viterbi_tags(hs, mask, top_k=self.crf_top_k)
            for psk in psks:
                psk, score = zip(*psk)
                ps.append(psk)

        if calculate_loss:
            log_likelihood = self.crf(hs, ls, mask)
            loss = -1 * log_likelihood / len(lengths)
        else:
            loss = torch.tensor(0.0, dtype=torch.float, device=device)

        return loss, ps

    def decode(self, us):
        with torch.no_grad():
            _, ps = self.forward(us, calculate_loss=False, decode=True)
        return ps
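
``predict_crf`` above can return the top-k Viterbi paths per sentence via ``crf_top_k``. A short sketch of the two call shapes, assuming an AllenNLP version whose ``viterbi_tags`` accepts a ``top_k`` argument (as the call above does) and purely illustrative tensors:

import torch
from allennlp.modules.conditional_random_field import ConditionalRandomField
from allennlp.nn.util import get_mask_from_sequence_lengths

crf = ConditionalRandomField(num_tags=4)
hs = torch.randn(2, 5, 4)                   # (batch, max_len, num_tags)
lengths = torch.tensor([5, 3])
mask = get_mask_from_sequence_lengths(lengths, max_length=5)

top1 = crf.viterbi_tags(hs, mask)           # [(tags, score), ...] - one path per instance
top3 = crf.viterbi_tags(hs, mask, top_k=3)  # [[(tags, score), ...], ...] - k paths per instance
best_tags_per_instance = [paths[0][0] for paths in top3]
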
Beispiel #18
0
class SpanScorerCRF(nn.Module):
    '''
    Span extractor
    '''
    def __init__(
        self,
        input_dim,
        num_tags,
        low_val=-5,
        high_val=5,
        incl_start_end=True,
        name=None,
    ):
        super(SpanScorerCRF, self).__init__()

        self.input_dim = input_dim
        self.num_tags = num_tags
        self.low_val = low_val
        self.high_val = high_val
        self.incl_start_end = incl_start_end
        self.name = name

        self.span_to_seq, self.seq_to_span = label_map(num_tags)

        self.num_tags_seq = len(self.seq_to_span)
        self.num_tags_span = len(self.span_to_seq)

        # Linear projection layer
        self.projection = nn.Linear(input_dim, self.num_tags_seq)

        # Create event-specific CRF
        self.crf = ConditionalRandomField( \
                        num_tags = self.num_tags_seq,
                        include_start_end_transitions = incl_start_end)

    def forward(self,
                seq_tensor,
                seq_mask,
                span_map,
                span_indices,
                verbose=False):
        '''
        Calculate logits
        '''
        # Dimensionality
        batch_size, max_seq_len, input_dim = tuple(seq_tensor.shape)

        # Project input tensor sequence to logits
        seq_scores = self.projection(seq_tensor)
        # Decode sequence tags

        # Viterbi decode
        best_paths = self.crf.viterbi_tags( \
                                        logits = seq_scores,
                                        mask = seq_mask)
        seq_pred, score = zip(*best_paths)
        seq_pred = list(seq_pred)
        # Convert sequence tags to span predictions
        # Get spans from sequence tags
        #   Converts list of list of predicted label indices to
        #   tensor of size (batch_size, num_spans)
        span_pred = seq_tags_to_spans( \
                                seq_tags = seq_pred,
                                span_map = span_map,
                                seq_tag_map = self.seq_to_span)

        span_pred = span_pred.to(seq_tensor.device)

        # Get scores from labels
        span_pred = F.one_hot(span_pred,
                              num_classes=self.num_tags_span).float()

        #print('crf seq  pos: ', sum([int(w > 0) for W in seq_pred for w in W]))
        #print('crf span pos: ', (span_pred > 0).sum().tolist())
        #print(span_pred)
        return (seq_scores, span_pred)

    def loss(self, span_labels, seq_scores, seq_mask, span_map):

        batch_size, max_len, embed_dim = tuple(seq_scores.shape)



        seq_labels = get_seq_labels( \
                        span_labels = span_labels,
                        span_map = span_map,
                        span_to_seq = self.span_to_seq,
                        max_len = max_len)

        seq_mask[:, 0] = True

        # Get loss (negative log likelihood)
        loss = -self.crf( \
                            inputs = seq_scores,
                            tags = seq_labels,
                            mask = seq_mask)
        #print('loss', loss)

        return loss
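
Two small idioms from ``forward`` above, shown on dummy values: unpacking the ``viterbi_tags`` output with ``zip(*...)``, and turning predicted span label indices into one-hot float scores with ``F.one_hot``:

import torch
import torch.nn.functional as F

# Unpack [(tag_sequence, score), ...] into parallel tuples.
best_paths = [([1, 2, 0], -3.4), ([0, 0, 1], -1.1)]
seq_pred, score = zip(*best_paths)
seq_pred = list(seq_pred)  # [[1, 2, 0], [0, 0, 1]]

# One-hot "scores" for span predictions, shape (batch_size, num_spans, num_tags_span).
num_tags_span = 3
span_pred = torch.tensor([[0, 2, 1],
                          [1, 0, 0]])
span_scores = F.one_hot(span_pred, num_classes=num_tags_span).float()
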
Beispiel #19
0
class BiLSTMCRF(Model):

    def __init__(self, vocab: Vocabulary, embedding_dim=300, embedder_type=None, bert_trainable=True, **kwargs):
        super().__init__(vocab)
        for k in kwargs:
            self.__setattr__(k, kwargs[k])

        text_field_embedder = get_embeddings(embedder_type, self.vocab, embedding_dim, bert_trainable)
        embedding_dim = text_field_embedder.get_output_dim()

        encoder = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedding_dim, self.num_rnn_units, batch_first=True, bidirectional=True, dropout=self.dropout_rate))

        self.label_namespace = label_namespace = 'ner_bio_labels'
        self.num_tags = self.vocab.get_vocab_size(label_namespace)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.dropout = torch.nn.Dropout(self.dropout_rate)

        output_dim = self.encoder.get_output_dim()
        self.tag_projection_layer = TimeDistributed(Linear(output_dim,
                                                           self.num_tags))

        self.label_encoding = label_encoding = 'BIO'
        labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
        constraints = allowed_transitions(self.label_encoding, labels)

        self.include_start_end_transitions = True
        self.crf = ConditionalRandomField(
            self.num_tags, constraints,
            include_start_end_transitions=True
        )

        self._f1_metric = SpanBasedF1Measure(self.vocab,
                                             tag_namespace=label_namespace,
                                             label_encoding=label_encoding)
        self._verbose_metrics = False

    @overrides
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                ner_bio: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                # pylint: disable=unused-argument
                **kwargs) -> Dict[str, torch.Tensor]:
        embedded_text_input = self.text_field_embedder(text)
        mask = util.get_text_field_mask(text)
        tags = ner_bio
        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        lstm_feature = self.encoder(embedded_text_input, mask)

        if self.dropout:
            lstm_feature = self.dropout(lstm_feature)
        logits = self.tag_projection_layer(lstm_feature)

        best_paths = self.crf.viterbi_tags(logits, mask)

        # Just get the tags and ignore the score.
        predicted_tags = [x for x, y in best_paths]

        output = {"logits": logits, "mask": mask, "tags": predicted_tags}

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(logits, tags, mask)

            output["loss"] = -log_likelihood
            batch_size = tags.shape[0]
            output['loss'] /= batch_size

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = logits * 0.
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            self._f1_metric(class_probabilities, tags, mask.float())
        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics_to_return = {}
        f1_dict = self._f1_metric.get_metric(reset=reset)
        if self._verbose_metrics:
            metrics_to_return.update(f1_dict)
        else:
            metrics_to_return.update({
                x: y for x, y in f1_dict.items() if
                "overall" in x})
        return metrics_to_return

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids to the actual tags.
        ``output_dict["tags"]`` is a list of lists of tag_ids,
        so we use an ugly nested list comprehension.
        """
        output_dict["tags"] = [
            [self.vocab.get_token_from_index(tag, namespace=self.label_namespace)
             for tag in instance_tags]
            for instance_tags in output_dict["tags"]
        ]

        return output_dict
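
Unlike most of the other examples, ``BiLSTMCRF`` divides the CRF negative log-likelihood by the batch size. A minimal sketch of that loss on dummy tensors, assuming AllenNLP's ``ConditionalRandomField`` (whose forward pass returns the log-likelihood summed over the batch):

import torch
from allennlp.modules.conditional_random_field import ConditionalRandomField

crf = ConditionalRandomField(num_tags=3)
logits = torch.randn(4, 7, 3)              # (batch, seq_len, num_tags)
tags = torch.randint(0, 3, (4, 7))
mask = torch.ones(4, 7, dtype=torch.bool)  # bool in recent AllenNLP; older releases used long masks

log_likelihood = crf(logits, tags, mask)   # summed over the batch
loss = -log_likelihood / tags.shape[0]     # mean negative log-likelihood per instance
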