def forward(self,
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            **kwargs):
    """Score candidate spans and keep the top-k as predicted entity spans.

    Args:
        text: TextField tensors consumed by ``self.embedder``.
        spans: (batch_size, num_spans, 2) span boundary indices; padding
            spans are marked with negative indices.
        labels: optional (batch_size, num_spans) gold span labels, 0 = not
            an entity. When ``None`` (inference), no loss/metric is computed.

    Returns:
        dict with:
          - "loss": scalar loss, or ``None`` when ``labels`` is ``None``
            (the original crashed in that case despite the default);
          - "predict_true": (batch_size, num_spans) bool mask of kept spans.
    """
    text_embeddings = self._lexical_dropout(self.embedder(text))

    # Shape: (batch_size, num_spans); padding spans have start < 0.
    # BUG FIX: the original applied .squeeze(-1) here, which would wrongly
    # collapse the span dimension when num_spans == 1 — the indexing already
    # yields a 2-D mask, so no squeeze is needed.
    span_mask = (spans[:, :, 0] >= 0).float()
    # Clamp padding indices (-1) to 0 so the span extractor gets valid input.
    spans = F.relu(spans.float()).long()

    span_embeddings = self._span_extractor(
        text_embeddings, spans, span_indices_mask=span_mask)
    span_scores = self.feedforward_scorer(span_embeddings).squeeze(-1)
    # log(0) = -inf drives padded spans to probability 0 after the sigmoid.
    span_scores += span_mask.log()
    span_scores = span_scores.sigmoid()

    # Keep the top keep_rate fraction of candidate spans per sentence.
    topk_idx = torch.topk(span_scores, int(self.keep_rate * spans.shape[1]))[-1]
    predict_true = span_scores.new_zeros(span_scores.shape).scatter_(
        1, topk_idx, 1).bool()

    loss = None
    if labels is not None:
        is_entity = (labels != 0).float().reshape(-1)
        flat_scores = span_scores.reshape(-1)
        loss = self.loss(flat_scores, is_entity)

        # Two-column (negative, positive) layout expected by the F1 metric.
        predict_true_flatten = predict_true.reshape(-1).unsqueeze(-1)
        predict = torch.cat([~predict_true_flatten, predict_true_flatten], -1)
        self._metric_f1(predict, is_entity, mask=span_mask.reshape(-1))

        # Always keep gold spans so downstream recall is not capped.
        predict_true |= labels.bool()

    return {"loss": loss, "predict_true": predict_true}
def forward(self, sent_tokens, pats_tokens, sent_tokens_mask, pats_tokens_mask,
            sent: torch.IntTensor, mid: torch.IntTensor,
            rel_label: torch.IntTensor, pat_label: torch.IntTensor,
            pattern_rels: torch.IntTensor, pats: torch.IntTensor,
            weights: torch.FloatTensor, is_train=True):
    """Joint sentence/pattern relation classifier; all inputs arrive as numpy.

    Shapes below use B = batch size, P = number of patterns, R = number of
    relations, L/Lp = (padded) sentence/pattern length:
        sent_tokens, pats_tokens (+ *_mask): BERT wordpiece ids and attention
            masks for sentences and patterns.
        sent: (B, L) token ids, 0-padded (not embeddings yet).
        mid: (B, L) token ids of the between-entities segment, 0-padded.
        rel_label: (B, R) one-hot relation labels.
        pat_label: (B,) pattern id per sentence, -1 when there is none
            (converted but currently unused in this method).
        pattern_rels: (P, R) relation assignment of each pattern.
        pats: (P, Lp) pattern token ids, 0-padded.
        weights: (P,) per-pattern weights.
        is_train: enables the supervised / pseudo-label / similarity losses.

    Returns:
        (golds, preds, val, loss): gold relation ids, predicted relation ids,
        per-example predictive entropy, and the scalar loss (0.0 at eval).
    """
    self.is_train = is_train
    device = self.config.device

    # Inputs arrive as numpy arrays; move everything to the device once.
    sent = torch.from_numpy(sent).long().to(device)
    mid = torch.from_numpy(mid).long().to(device)
    rel_label = torch.from_numpy(rel_label).long().to(device)
    pat_label = torch.from_numpy(pat_label).long().to(device)
    pattern_rels = torch.from_numpy(pattern_rels).float().to(device)
    pats = torch.from_numpy(pats).long().to(device)
    weights = torch.from_numpy(weights).float().to(device)
    sent_tokens = torch.from_numpy(sent_tokens).long().to(device)
    pats_tokens = torch.from_numpy(pats_tokens).long().to(device)
    sent_tokens_mask = torch.from_numpy(sent_tokens_mask).bool().to(device)
    pats_tokens_mask = torch.from_numpy(pats_tokens_mask).bool().to(device)

    rel_label = torch.argmax(rel_label, -1)
    pattern_rels_label = torch.argmax(pattern_rels, -1)

    # Trim each 0-padded batch to its longest real sequence.
    sent_mask = sent.bool()
    sent_max_len = torch.max(torch.sum(sent_mask, dim=1))
    sent_mask = sent_mask[:, :sent_max_len]
    sent = sent[:, :sent_max_len]

    mid_mask = mid.bool()
    mid_max_len = torch.max(torch.sum(mid_mask, dim=1))
    mid_mask = mid_mask[:, :mid_max_len]
    mid = mid[:, :mid_max_len]

    pat_mask = pats.bool()
    pat_max_len = torch.max(torch.sum(pat_mask, dim=1))
    pat_mask = pat_mask[:, :pat_max_len]
    pat = pats[:, :pat_max_len]

    # NOTE: the original also embedded `sent` into an unused local
    # (`sent_embedding`); that dead computation has been removed.
    mid_embedding = self.get_embedding(mid)
    # BUG FIX: embed the truncated `pat`, not the full `pats`, so the
    # embedding length matches `pat_mask` inside att_match.
    pat_embedding = self.get_embedding(pat)

    # encoder: [CLS] representation of each sentence and pattern.
    sent_d = self.bert(sent_tokens, attention_mask=sent_tokens_mask)[0][:, 0, :]
    pat_d = self.bert_no_grad(pats_tokens, attention_mask=pats_tokens_mask)[0][:, 0, :]

    # similarity between sentence middles and patterns
    sim, pat_sim = self.att_match(mid_embedding, pat_embedding, mid_mask,
                                  pat_mask, self.keep_prob, self.is_train)

    # Contrastive pattern loss: same-relation pattern pairs should be at
    # least tau-similar; different-relation pairs should not be similar.
    neg_idxs = torch.matmul(pattern_rels, torch.transpose(pattern_rels, 1, 0))
    pat_pos = torch.square(
        torch.max(self.config.tau - pat_sim, torch.zeros_like(pat_sim)))
    pat_pos = torch.max(pat_pos - (1 - neg_idxs) * 1e30, dim=1)[0]
    pat_neg = torch.square(torch.max(pat_sim, torch.zeros_like(pat_sim)))
    pat_neg = torch.max(pat_neg - 1e30 * neg_idxs, dim=1)[0]
    l_sim = torch.sum(weights * (pat_pos + pat_neg), dim=0)

    logit = self.fc_sent2rel(sent_d)
    pred = F.softmax(logit, dim=1)

    if self.is_train is True:
        # Supervised loss on the ground-truth slice of the batch.
        l_a = F.cross_entropy(logit[:self.config.gt_batch_size],
                              rel_label[:self.config.gt_batch_size])
        # Pseudo-labels for the unlabeled slice: take the relation of the
        # best-matching pattern, weighted by match confidence.
        xsim = sim[self.config.gt_batch_size:]
        pseudo_rel = pattern_rels_label[torch.argmax(xsim, dim=1)]
        bound = torch.max(xsim, dim=1)[0]
        weight = F.softmax(10 * bound, dim=0)
        l_u = torch.sum(weight * F.cross_entropy(
            logit[self.config.gt_batch_size:], pseudo_rel, reduction='none'))
        pat2rel = self.fc_pat2rel(pat_d)
        pat2rel_pred = F.softmax(pat2rel, dim=1)
        # NOTE(review): F.cross_entropy applies log_softmax internally, so
        # feeding it the softmax output double-normalizes; passing the raw
        # `pat2rel` logits looks intended — confirm before changing, since
        # trained checkpoints may depend on the current scaling.
        l_pat = F.cross_entropy(pat2rel_pred, pattern_rels_label)
        loss = (l_a + self.config.alpha * l_pat
                + self.config.gamma * l_u + self.config.beta * l_sim)
    else:
        loss = 0.0

    preds = torch.argmax(pred, dim=1)
    # Predictive entropy, clamped to avoid log(0).
    val = torch.sum((0 - torch.log(torch.clamp(pred, 1e-5, 1.0))) * pred, dim=1)
    golds = rel_label
    return golds, preds, val, loss
def forward(
        self,  # type: ignore
        tokens: TextFieldTensors,
        spans: torch.IntTensor,
        ner_labels: torch.IntTensor = None,
        rel_span_indices: torch.IntTensor = None,
        rel_labels: torch.IntTensor = None,
        span_masks: torch.IntTensor = None,
        relation_masks: torch.IntTensor = None,
        rels_sample_masks: torch.BoolTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, Dict]:
    """SpERT-style joint entity/relation forward pass.

    Args:
        tokens: TextField tensors; ``tokens["tokens"]["token_ids"]`` feeds BERT.
        spans: (batch, num_spans, 2) inclusive candidate span boundaries.
        ner_labels / rel_labels: optional gold labels. When both are given,
            the loss and F1 metrics are computed; otherwise decoded relation
            predictions are returned.
        rel_span_indices: (batch, num_pairs, 2) candidate span-pair indices;
            derived via ``self._filter_spans`` when ``rel_labels`` is None.
        span_masks / relation_masks / rels_sample_masks: masking tensors
            passed through to the entity/relation classifiers and the loss.
        metadata: per-instance metadata (unused in this method).

    Returns:
        ``{"loss": ...}`` when both label tensors are provided, otherwise
        ``{"relations": converted_predictions}``.
    """
    embedded_text_input = self._text_field_embedder(tokens)

    # [CLS] context vector; older transformers versions return a tuple
    # instead of a ModelOutput, hence the AttributeError fallback.
    try:
        entity_ctx = self._bert(
            tokens["tokens"]["token_ids"]).last_hidden_state[:, 0, :]
    except AttributeError:
        entity_ctx = self._bert(tokens["tokens"]["token_ids"])[0][:, 0, :]
    embedded_text_input = torch.cat(
        (entity_ctx.unsqueeze(1), embedded_text_input), dim=1)
    batch_size = embedded_text_input.shape[0]

    # Span-width embeddings (boundaries are inclusive).
    entity_sizes = spans[:, :, 1] - spans[:, :, 0] + 1
    size_embeddings = self.size_embeddings(entity_sizes)
    entity_clf, entity_spans_pool = self._classify_entities(
        embedded_text_input, span_masks, size_embeddings, entity_ctx)

    # BUG FIX: entity_sample_masks is consumed by convert_predictions on
    # both paths below; the original only defined it when rel_labels was
    # None, raising NameError whenever labels were supplied.
    entity_sample_masks = torch.ones((batch_size, entity_clf.shape[1]))

    # TODO: at inference time, relation candidates should be derived from
    # the *predicted* entities (no gold entities/labels are available);
    # see git history for the sketched permutation-based candidate code.
    if rel_labels is None:
        ctx_size = embedded_text_input.shape[1]
        rel_span_indices, relation_masks, rel_sample_masks = self._filter_spans(
            entity_clf, spans, entity_sample_masks, ctx_size)
        rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)

    # Hoisted out of the branches — the two arms were byte-identical.
    # One copy of the encoded text per relation chunk.
    h_large = embedded_text_input.unsqueeze(1).repeat(
        1, max(min(rel_span_indices.shape[1], self._max_pairs), 1), 1, 1)
    rel_clf = torch.zeros(
        [batch_size, rel_span_indices.shape[1], self._relation_types]).to(
            self.rel_classifier.weight.device)

    # Obtain relation logits chunk-by-chunk to bound memory usage.
    for i in range(0, rel_span_indices.shape[1], self._max_pairs):
        chunk_rel_logits = self._classify_relations(
            entity_spans_pool, size_embeddings, rel_span_indices,
            relation_masks, h_large, i)
        rel_clf[:, i:i + self._max_pairs, :] = torch.sigmoid(chunk_rel_logits)

    # Decode predictions into the serializable relation format.
    converted_relations = []
    for batch in range(batch_size):
        batch_pred_entities, batch_pred_relations = self.convert_predictions(
            entity_clf[batch].unsqueeze(0),
            rel_clf[batch].unsqueeze(0),
            rel_span_indices[batch].unsqueeze(0),
            spans[batch].unsqueeze(0),
            entity_sample_masks[batch].unsqueeze(0),
            self._rel_filter_threshold,
        )
        batch_converted_relations = []
        for pred_relation in batch_pred_relations[0]:
            h_name, t_name = sorted(relation_args_names[pred_relation[2]])
            batch_converted_relations.append({
                "name": pred_relation[2],
                "ents": [
                    {
                        "name": h_name,
                        "start": pred_relation[0][0],
                        "end": pred_relation[0][1],
                    },
                    {
                        "name": t_name,
                        "start": pred_relation[1][0],
                        "end": pred_relation[1][1],
                    },
                ],
            })
        converted_relations.append(batch_converted_relations)

    # BUG FIX: `if ner_labels and rel_labels` coerces multi-element tensors
    # to bool, which raises "Boolean value of Tensor ... is ambiguous";
    # compare against None explicitly.
    if ner_labels is not None and rel_labels is not None:
        batch_loss = self.compute_loss(entity_logits=entity_clf,
                                       rel_logits=rel_clf,
                                       rel_types=rel_labels,
                                       entity_types=ner_labels,
                                       rel_sample_masks=rels_sample_masks)
        self._f1_entities(entity_clf, ner_labels)
        self._f1_relation(rel_clf.squeeze(), rel_labels.bool().squeeze())
        return {"loss": batch_loss}
    return {"relations": converted_relations}