Example #1
    def forward(self,
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                labels: torch.IntTensor = None,
                **kwargs):
        text_embeddings = self._lexical_dropout(self.embedder(text))

        # Shape: (batch_size, num_spans); padded spans are marked with -1 indices.
        span_mask = (spans[:, :, 0] >= 0).float()
        # Clamp the -1 padding indices to 0 so the span extractor sees valid indices.
        spans = F.relu(spans.float()).long()

        span_embeddings = self._span_extractor(text_embeddings,
                                               spans,
                                               span_indices_mask=span_mask)

        span_scores = self.feedforward_scorer(span_embeddings)

        span_scores = span_scores.squeeze(-1)
        # Masked (padding) spans get log(0) = -inf, so their sigmoid score is 0.
        span_scores += span_mask.log()
        span_scores = span_scores.sigmoid()
        # Keep the top keep_rate fraction of candidate spans per sentence.
        topk_idx = torch.topk(span_scores,
                              int(self.keep_rate * spans.shape[1]))[-1]
        predict_true = span_scores.new_zeros(span_scores.shape).scatter_(
            1, topk_idx, 1).bool()
        # Binary target: 1 for spans with a non-zero (entity) label.
        is_entity = (labels != 0).float()
        span_scores = span_scores.reshape(-1)
        is_entity = is_entity.reshape(-1)
        loss = self.loss(span_scores, is_entity)

        # Stack [not-entity, entity] columns so the F1 metric sees two classes.
        predict_true_flatten = predict_true.reshape(-1)
        predict_true_flatten = predict_true_flatten.unsqueeze(-1)
        predict_false_flatten = ~predict_true_flatten
        predict = torch.cat([predict_false_flatten, predict_true_flatten], -1)
        self._metric_f1(predict, is_entity, mask=span_mask.reshape(-1))

        # Always keep gold entity spans in addition to the predicted top-k spans.
        predict_true |= labels.bool()
        output_dict = {"loss": loss, "predict_true": predict_true}
        return output_dict
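The pruning step above keeps the top keep_rate fraction of spans by writing 1s into a zero tensor with scatter_ and casting to bool. A minimal, self-contained sketch of that trick with toy scores and a made-up keep_rate (values are illustrative, not from the model):

import torch

# Toy span scores for 2 sentences with 5 candidate spans each
# (hypothetical values; in the model they come from feedforward_scorer).
span_scores = torch.tensor([[0.9, 0.1, 0.4, 0.8, 0.2],
                            [0.3, 0.7, 0.6, 0.1, 0.5]])
keep_rate = 0.4  # keep the top 40% of spans per sentence (assumed value)

k = int(keep_rate * span_scores.shape[1])  # k = 2
topk_idx = torch.topk(span_scores, k, dim=1).indices

# scatter_ writes 1 at the top-k positions, giving a boolean keep mask.
predict_true = span_scores.new_zeros(span_scores.shape).scatter_(
    1, topk_idx, 1).bool()
# predict_true:
# tensor([[ True, False, False,  True, False],
#         [False,  True,  True, False, False]])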
Example #2
    def forward(self,
                sent_tokens,
                pats_tokens,
                sent_tokens_mask,
                pats_tokens_mask,
                sent: torch.IntTensor,
                mid: torch.IntTensor,
                rel_label: torch.IntTensor,
                pat_label: torch.IntTensor,
                pattern_rels: torch.IntTensor,
                pats: torch.IntTensor,
                weights: torch.FloatTensor,
                is_train=True):
        """
        8是batchsize
        sent: 8 * 110, 现在还不是embedding
        mid: 8 * 110
        rel: 8 (每个元素小于rel_nums 3)
        pat: 8, (每个数字代表对应的pattern_id), 如果无对应的,就是-1


        patterns:
        pattern_rels: pattern_num * rel_num, 每条pattern对应的relation
        pats_token: pattern_num * pat_token_len, 每条patterns的embedding
        weights: pattern的权重, 维度: patterns_num
        """
        self.is_train = is_train
        device = self.config.device

        sent = torch.from_numpy(sent).long().to(device)
        mid = torch.from_numpy(mid).long().to(device)
        rel_label = torch.from_numpy(rel_label).long().to(device)
        pat_label = torch.from_numpy(pat_label).long().to(device)
        pattern_rels = torch.from_numpy(pattern_rels).float().to(device)
        pats = torch.from_numpy(pats).long().to(device)
        weights = torch.from_numpy(weights).float().to(device)
        sent_tokens = torch.from_numpy(sent_tokens).long().to(device)
        pats_tokens = torch.from_numpy(pats_tokens).long().to(device)
        sent_tokens_mask = torch.from_numpy(sent_tokens_mask).bool().to(device)
        pats_tokens_mask = torch.from_numpy(pats_tokens_mask).bool().to(device)

        # One-hot labels -> class indices.
        rel_label = torch.argmax(rel_label, -1)
        pattern_rels_label = torch.argmax(pattern_rels, -1)

        sent_mask = sent.bool()
        sent_len = torch.sum(sent_mask, dim=1)
        sent_max_len = torch.max(sent_len)
        sent_mask = sent_mask[:, :sent_max_len]
        sent = sent[:, :sent_max_len]

        mid_mask = mid.bool()
        mid_len = torch.sum(mid_mask, dim=1)
        mid_max_len = torch.max(mid_len)
        mid_mask = mid_mask[:, :mid_max_len]
        mid = mid[:, :mid_max_len]

        pat_mask = pats.bool()
        pat_len = torch.sum(pat_mask, dim=1)
        pat_max_len = torch.max(pat_len)
        pat_mask = pat_mask[:, :pat_max_len]
        # Truncate pats in place so the embedding below matches pat_mask,
        # mirroring the sent/mid handling above.
        pats = pats[:, :pat_max_len]

        sent_embedding = self.get_embedding(sent)
        mid_embedding = self.get_embedding(mid)
        pat_embedding = self.get_embedding(pats)

        # encoder: take the [CLS] representation from BERT
        sent_d = self.bert(sent_tokens,
                           attention_mask=sent_tokens_mask)[0][:, 0, :]
        pat_d = self.bert_no_grad(pats_tokens,
                                  attention_mask=pats_tokens_mask)[0][:, 0, :]

        # similarity
        sim, pat_sim = self.att_match(mid_embedding, pat_embedding, mid_mask,
                                      pat_mask, self.keep_prob, self.is_train)

        # neg_idxs[i, j] = 1 if patterns i and j share a relation; the hinge terms
        # below pull same-relation patterns within margin tau and push the rest apart.
        neg_idxs = torch.matmul(pattern_rels,
                                torch.transpose(pattern_rels, 1, 0))
        pat_pos = torch.square(
            torch.max(self.config.tau - pat_sim, torch.zeros_like(pat_sim)))
        pat_pos = torch.max(pat_pos - (1 - neg_idxs) * 1e30, dim=1)[0]
        pat_neg = torch.square(torch.max(pat_sim, torch.zeros_like(pat_sim)))
        pat_neg = torch.max(pat_neg - 1e30 * neg_idxs, dim=1)[0]
        l_sim = torch.sum(weights * (pat_pos + pat_neg), dim=0)

        logit = self.fc_sent2rel(sent_d)
        pred = F.softmax(logit, dim=1)

        if self.is_train:
            # Supervised loss on the labeled (ground-truth) part of the batch.
            l_a = F.cross_entropy(logit[:self.config.gt_batch_size],
                                  rel_label[:self.config.gt_batch_size])

            # Unlabeled part of the batch: pseudo relation labels come from the most
            # similar pattern; each example is weighted by its similarity.
            xsim = sim[self.config.gt_batch_size:]
            # xsim = xsim.detach()
            # xsim.requires_grad = False
            pseudo_rel = pattern_rels_label[torch.argmax(xsim, dim=1)]
            bound = torch.max(xsim, dim=1)[0]
            weight = F.softmax(10 * bound, dim=0)

            l_u = torch.sum(weight *
                            F.cross_entropy(logit[self.config.gt_batch_size:],
                                            pseudo_rel,
                                            reduction='none'))

            pat2rel = self.fc_pat2rel(pat_d)
            # F.cross_entropy expects raw logits, so pass pat2rel without softmax.
            l_pat = F.cross_entropy(pat2rel, pattern_rels_label)
            loss = l_a + self.config.alpha * l_pat + self.config.gamma * l_u + self.config.beta * l_sim
            # loss = l_a + self.config.alpha * l_pat + self.config.beta * l_u
        else:
            loss = 0.0

        preds = torch.argmax(pred, dim=1)
        # Per-example predictive entropy, used as a confidence value.
        val = torch.sum(-torch.log(torch.clamp(pred, 1e-5, 1.0)) * pred,
                        dim=1)
        golds = rel_label

        return golds, preds, val, loss
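The l_sim term above is a hinge-style contrastive loss over pattern pairs. The sketch below reproduces that computation on made-up similarities, with a hypothetical margin tau and uniform weights (torch.clamp(x, min=0) stands in for torch.max(x, torch.zeros_like(x)); the two are equivalent):

import torch

# Made-up pairwise similarities between 3 patterns (pat_sim in the model)
# and their one-hot relation assignments (pattern_rels in the model).
pat_sim = torch.tensor([[1.0, 0.8, 0.1],
                        [0.8, 1.0, 0.2],
                        [0.1, 0.2, 1.0]])
pattern_rels = torch.tensor([[1., 0.],   # patterns 0 and 1 share a relation
                             [1., 0.],
                             [0., 1.]])
tau = 0.7                                # margin (self.config.tau in the model)
weights = torch.tensor([1.0, 1.0, 1.0])  # uniform pattern weights

# neg_idxs[i, j] = 1 if patterns i and j belong to the same relation.
neg_idxs = pattern_rels @ pattern_rels.t()

# Same-relation pairs are penalized if their similarity falls below tau ...
pat_pos = torch.square(torch.clamp(tau - pat_sim, min=0.0))
pat_pos = torch.max(pat_pos - (1 - neg_idxs) * 1e30, dim=1)[0]
# ... different-relation pairs are penalized for any positive similarity.
pat_neg = torch.square(torch.clamp(pat_sim, min=0.0))
pat_neg = torch.max(pat_neg - neg_idxs * 1e30, dim=1)[0]

l_sim = torch.sum(weights * (pat_pos + pat_neg))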
Example #3
File: spart.py Project: MSLars/mare
    def forward(
            self,  # type: ignore
            tokens: TextFieldTensors,
            spans: torch.IntTensor,
            ner_labels: torch.IntTensor = None,
            rel_span_indices: torch.IntTensor = None,
            rel_labels: torch.IntTensor = None,
            span_masks: torch.IntTensor = None,
            relation_masks: torch.IntTensor = None,
            rels_sample_masks: torch.BoolTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, Dict]:
        embedded_text_input = self._text_field_embedder(tokens)
        try:
            entity_ctx = self._bert(
                tokens["tokens"]["token_ids"]).last_hidden_state[:, 0, :]
        except AttributeError:
            entity_ctx = self._bert(tokens["tokens"]["token_ids"])[0][:, 0, :]
        embedded_text_input = torch.cat(
            (entity_ctx.unsqueeze(1), embedded_text_input), dim=1)
        batch_size = embedded_text_input.shape[0]

        entity_sizes = spans[:, :, 1] - spans[:, :, 0] + 1

        size_embeddings = self.size_embeddings(entity_sizes)

        entity_clf, entity_spans_pool = self._classify_entities(
            embedded_text_input, span_masks, size_embeddings, entity_ctx)

        # TODO If we have no gold entities, we cannot specify relation candidates!
        # entity_max_logits_index = entity_clf.max(dim=2).indices
        # relation_candidates = []
        # relation_masks = []
        # for batch in range(entity_max_logits_index.shape[0]):
        #
        #     entity_indices = entity_max_logits_index[batch].nonzero(as_tuple=True)[0]
        #
        #     new_candidates = list(itertools.permutations(entity_indices.tolist(), 2))
        #
        #     for nc in new_candidates:
        #
        #         start_entity_span = tuple(spans[batch][nc[0]].tolist())
        #         end_entity_span = tuple(spans[batch][nc[1]].tolist())
        #
        #         relation_masks += [create_rel_mask(start_entity_span, end_entity_span, embedded_text_input.shape[1])]
        #
        #     relation_candidates += []

        # TODO: for evaluation we have NO labels matching these labels; we need to
        # reuse the ones from span labeling etc. again!
        # rel_span_indices = torch.tensor(relation_candidates, device=entity_clf.device)

        # classify relations
        # Every candidate span counts as valid here; the mask is also needed by
        # convert_predictions below, regardless of which branch is taken.
        entity_sample_masks = torch.ones((batch_size, entity_clf.shape[1]),
                                         device=entity_clf.device)
        if rel_labels is None:
            # No gold relations: build candidate pairs from the predicted entities.
            ctx_size = embedded_text_input.shape[1]

            rel_span_indices, relation_masks, rel_sample_masks = self._filter_spans(
                entity_clf, spans, entity_sample_masks, ctx_size)

            rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)

        h_large = embedded_text_input.unsqueeze(1).repeat(
            1, max(min(rel_span_indices.shape[1], self._max_pairs), 1), 1, 1)
        rel_clf = torch.zeros(
            [batch_size, rel_span_indices.shape[1],
             self._relation_types]).to(self.rel_classifier.weight.device)

        # obtain relation logits
        # chunk processing to reduce memory usage
        for i in range(0, rel_span_indices.shape[1], self._max_pairs):
            # classify relation candidates
            chunk_rel_logits = self._classify_relations(
                entity_spans_pool, size_embeddings, rel_span_indices,
                relation_masks, h_large, i)
            chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf

        converted_relations = []

        for batch in range(batch_size):
            batch_pred_entities, batch_pred_relations = self.convert_predictions(
                entity_clf[batch].unsqueeze(0),
                rel_clf[batch].unsqueeze(0),
                rel_span_indices[batch].unsqueeze(0),
                spans[batch].unsqueeze(0),
                entity_sample_masks[batch].unsqueeze(0),
                self._rel_filter_threshold,
            )

            batch_converted_relations = []
            for pred_relation in batch_pred_relations[0]:

                h_name, t_name = sorted(relation_args_names[pred_relation[2]])
                converted_relation = {
                    "name":
                    pred_relation[2],
                    "ents": [
                        {
                            "name": h_name,
                            "start": pred_relation[0][0],
                            "end": pred_relation[0][1],
                        },
                        {
                            "name": t_name,
                            "start": pred_relation[1][0],
                            "end": pred_relation[1][1],
                        },
                    ]
                }

                batch_converted_relations += [converted_relation]

            converted_relations += [batch_converted_relations]

        if ner_labels is not None and rel_labels is not None:

            batch_loss = self.compute_loss(entity_logits=entity_clf,
                                           rel_logits=rel_clf,
                                           rel_types=rel_labels,
                                           entity_types=ner_labels,
                                           rel_sample_masks=rels_sample_masks)

            self._f1_entities(entity_clf, ner_labels)
            #self._f1_relation(rel_clf, rel_labels.bool())
            self._f1_relation(rel_clf.squeeze(), rel_labels.bool().squeeze())

            return {"loss": batch_loss}

        return {"relations": converted_relations}