def argmax(self) -> Tuple[LongTensor, LongTensor]:
    """Compute the most probable labeled dependency tree.

    Uses the projective `DependencyCRF` decoder when ``self.proj`` is set,
    otherwise falls back to a maximum-spanning-tree decoder (``find_mst``),
    which is always multiroot.

    Returns:
        - Tensor of shape (B, N) containing the head positions of the best tree.
        - Tensor of shape (B, N) containing the dependency types for the
          corresponding head-dependent relation.
    """
    assert self.mask is not None
    # Collapse the label dimension first: for every candidate arc keep only
    # its best-scoring type.  NOTE(review): assumes self.scores has shape
    # (bsz, slen, slen, n_types) — dim=3 is the type axis; confirm at call site.
    # each shape: (bsz, slen, slen)
    scores, best_types = self.scores.max(dim=3)
    # True token counts per sentence, derived from the padding mask.
    lengths = self.mask.long().sum(dim=1)
    if self.proj:
        # DependencyCRF expects lengths that exclude the artificial ROOT
        # position, hence `lengths - 1`.
        crf = DependencyCRF(_unconvert(scores), lengths - 1, multiroot=self.multiroot)
        # The CRF argmax is a one-hot arc tensor; taking max over dim=1
        # recovers, for each dependent, the position of its selected head.
        # shape: (bsz, slen)
        _, pred_heads = _convert(crf.argmax).max(dim=1)
        # ROOT has no incoming arc; pin its "head" to itself by convention.
        pred_heads[:, self.ROOT] = self.ROOT
    else:
        if not self.multiroot:
            # The MST decoder below cannot enforce a single root, so a
            # multiroot=False request is honored only approximately.
            warnings.warn(
                "argmax for non-projective is still multiroot although multiroot=False"
            )
        # shape: (bsz, slen)
        pred_heads = find_mst(scores, lengths.tolist())
    # Pick, for each dependent j, the best type of the arc from its predicted
    # head: gather along dim 1 means best_types is indexed [b, head, dep] —
    # NOTE(review): presumed layout, verify against how self.scores is built.
    # shape: (bsz, slen)
    pred_types = best_types.gather(1, pred_heads.unsqueeze(1)).squeeze(1)
    return pred_heads, pred_types  # type: ignore
def marginals(self) -> Tensor:
    """Compute the arc marginal probabilities.

    Returns:
        Tensor of shape (B, N, N, L) containing the arc marginal probabilities.
    """
    assert self.mask is not None
    if not self.proj:
        # Non-projective trees are handled by the generic routine.
        return compute_marginals(self.scores, self.mask, self.multiroot)

    # Projective case: delegate to torch-struct's DependencyCRF, which
    # expects lengths that exclude the artificial ROOT position.
    seq_lens = self.mask.long().sum(dim=1)
    crf = DependencyCRF(_unconvert(self.scores), seq_lens - 1, multiroot=self.multiroot)
    probs = _convert(crf.marginals)
    # An arc pointing into the artificial root can never occur.
    probs[:, :, self.ROOT] = 0
    # A token can never be its own head: zero out the self-loop diagonal.
    diag = torch.eye(probs.size(1)).to(probs.device).unsqueeze(2).bool()
    return probs.masked_fill(diag, 0)
def train(train_iter, val_iter, model):
    """Train `model` on batches from `train_iter`, periodically validating.

    Maximizes the DependencyCRF log-likelihood of the gold trees via AdamW
    with a linear warmup schedule.  NOTE(review): relies on a module-level
    `validate` that reads `model` from the enclosing scope — confirm.
    """
    opt = AdamW(model.parameters(), lr=1e-4, eps=1e-8)
    scheduler = WarmupLinearSchedule(opt, warmup_steps=20, t_total=2500)
    model.train()
    losses = []
    for i, ex in enumerate(train_iter):
        opt.zero_grad()
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape
        # Model
        final = model(words.cuda(), mapper)
        # Zero arc scores beyond each sentence's true length so padding
        # positions cannot participate in the tree.
        for b in range(batch):
            final[b, lengths[b]-1:, :] = 0
            final[b, :, lengths[b]-1:] = 0
        # Sanity check that the batch fits the score matrix.  NOTE(review):
        # this check runs AFTER the zeroing loop above already indexed
        # `final` with these lengths, and the `+ 1` slack looks suspicious —
        # verify the intended off-by-one convention.
        if not lengths.max() <= final.shape[1] + 1:
            print("fail")
            continue
        dist = DependencyCRF(final, lengths=lengths)
        # Convert gold head indices to the CRF's one-hot arc representation.
        labels = dist.struct.to_parts(label, lengths=lengths).type_as(final)
        log_prob = dist.log_prob(labels)
        loss = log_prob.sum()
        # Gradient ascent on log-likelihood == descent on its negation.
        (-loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        scheduler.step()
        losses.append(loss.detach())
        # Report the running mean negative log-likelihood every 50 steps.
        if i % 50 == 1:
            print(-torch.tensor(losses).mean(), words.shape)
            losses = []
        if i % 600 == 500:
            validate(val_iter)
def log_partitions(self) -> Tensor:
    """Compute the log partition function.

    Returns:
        1-D tensor of length B containing the log partition functions.
    """
    assert self.mask is not None
    if not self.proj:
        # Non-projective case: generic matrix-tree style computation.
        return compute_log_partitions(self.scores, self.mask, self.multiroot)
    # Projective case: torch-struct's DependencyCRF provides the partition
    # directly; its lengths exclude the artificial ROOT position.
    seq_lens = self.mask.long().sum(dim=1)
    return DependencyCRF(
        _unconvert(self.scores), seq_lens - 1, multiroot=self.multiroot
    ).partition
def validate(val_iter):
    """Evaluate dependency accuracy of the global `model` on `val_iter`.

    Counts edge mismatches between the CRF argmax tree and the gold tree.
    NOTE(review): reads `model` from the enclosing module scope rather than
    taking it as a parameter.  There is deliberately no `torch.no_grad()`
    here — presumably because torch-struct computes `dist.argmax` through
    autograd, which requires gradient tracking; confirm before adding one.
    """
    incorrect_edges = 0
    total_edges = 0
    model.eval()
    for i, ex in enumerate(val_iter):
        words, mapper, _ = ex.word
        label, lengths = ex.head
        batch, _ = label.shape
        final = model(words.cuda(), mapper)
        # Zero arc scores beyond each sentence's true length so padding
        # positions cannot participate in the tree.
        for b in range(batch):
            final[b, lengths[b]-1:, :] = 0
            final[b, :, lengths[b]-1:] = 0
        dist = DependencyCRF(final, lengths=lengths)
        # Gold tree in the CRF's one-hot arc representation.
        gold = dist.struct.to_parts(label, lengths=lengths).type_as(dist.argmax)
        # Both tensors are 0/1; each wrong edge contributes one 1 in the
        # prediction and one in the gold, hence the division by 2.
        incorrect_edges += (dist.argmax[:, :].cpu() - gold[:, :].cpu()).abs().sum() / 2.0
        total_edges += gold.sum()
    print(total_edges, incorrect_edges)
    # Restore training mode for the caller's training loop.
    model.train()
nb_tr_steps += 1 torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm) optimizer.step() scheduler.step() else: b_tags = [tag[mask] for mask, tag in zip(b_label_masks, b_tags)] b_tags = pad_sequence(b_tags, batch_first=True, padding_value=0) loss_main, logits, labels, final = model(b_input_ids, b_tags, labels=b_labels, label_masks=b_label_masks) if not lengths.max() <= final.shape[1]: dep_loss = 0 else: dist = DependencyCRF(final, lengths=lengths) dep_labels = dist.struct.to_parts(b_tags, lengths=lengths).type_as(final) # [BATCH_SIZE, lengths, lengths] log_prob = dist.log_prob(dep_labels) dep_loss = log_prob.mean() #sum() if dep_loss < 0 : loss = loss_main -dep_loss/dep_loss_factor else: loss = loss_main loss.backward() tr_loss += loss.item() nb_tr_steps += 1 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
def eval(iter_data, model, tags2idx, device_name, mt=False):
    """Evaluate the tagger (and, when `mt=True`, the dependency head) on `iter_data`.

    Collects per-token tag predictions/probabilities (skipping positions whose
    reduced label is -1), reports seqeval metrics, and — in the multi-task
    case — edge accuracy of the DependencyCRF argmax trees.

    NOTE(review): this function shadows the builtin `eval`; consider renaming.

    Returns:
        Tuple of (predicted tag sequences, per-token probability lists, F1 score).
    """
    device = device_name
    logger.info("starting to evaluate")
    model = model.eval()
    # NOTE(review): `eval_accuracy` is assigned but never used below.
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps = 0
    predictions, true_labels, data_instances, probs = [], [], [], []
    dep_predictions, dep_gold_labels = [], []
    total_edges = 0
    incorrect_edges = 0
    for batch in tqdm(iter_data):
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_pos_ids, b_tag_ids, b_deptype_ids, b_labels, b_input_mask,\
            b_token_type_ids, b_label_masks, lengths = batch
        #print('b_input_ids size:', b_input_ids.size())  # batch_size*max_len
        lengths = torch.flatten(lengths)
        batch_size, _ = b_tag_ids.shape
        with torch.no_grad():
            if not mt:
                # Single-task model: tagging loss/logits only.
                tmp_eval_loss, logits, reduced_labels = model(
                    b_input_ids,
                    b_tag_ids,
                    token_type_ids=b_token_type_ids,
                    attention_mask=b_input_mask,
                    labels=b_labels,
                    label_masks=b_label_masks)
            else:
                # Multi-task model additionally returns `final`, the
                # dependency arc-score matrix.
                tmp_eval_loss, logits, reduced_labels, final = model(
                    b_input_ids, b_tag_ids, labels=b_labels,
                    label_masks=b_label_masks)
                if not lengths.max() <= final.shape[1]:  #+ 1:
                    #print("fail to evaluate for dependency:", "max length", lengths.max(), "final shape", final.shape[1])
                    #continue
                    # BUG(review): `b_tags` is not defined until later in the
                    # else-path of this batch, so this branch raises
                    # NameError (or reuses a stale value from a previous
                    # iteration).  The intended placeholder size is unclear.
                    out = torch.zeros(
                        b_tags.size())  # not sure about the size!
                    # I cannot think what the size should be
                else:
                    dist = DependencyCRF(final, lengths=lengths)
                    out = dist.argmax
                dep_predictions.append(out)
                # Keep only unmasked gold tags, then re-pad per batch.
                b_tags = [
                    tag[mask] for mask, tag in zip(b_label_masks, b_tag_ids)
                ]
                b_tags = pad_sequence(b_tags, batch_first=True, padding_value=0)
                # BUG(review): `dist` is only bound in the else-branch above;
                # if the length check fails, this line raises NameError.
                dep_gold = dist.struct.to_parts(
                    b_tags, lengths=lengths).type_as(out)
                dep_gold_labels.append(dep_gold)
                # 0/1 arc tensors: each wrong edge contributes twice, hence /2.
                incorrect_edges += (out[:, :].cpu() -
                                    dep_gold[:, :].cpu()).abs().sum() / 2.0
                total_edges += dep_gold.sum()
                #log_prob = dist.log_prob(dep_labels)
                #dep_loss = log_prob.sum()
        # Restrict softmax outputs to the known tag columns.
        tags_idx = [tags2idx[t] for t in tags2idx]
        logits_probs = F.softmax(logits, dim=2)[:, :, tags_idx]
        preds = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
        #print('***',logits_probs)
        #print('logits size:',logits.size())  # batch_size*sentence_len(before padding)
        logits_probs = logits_probs.detach().cpu().numpy()
        preds = preds.detach().cpu().numpy()
        reduced_labels = reduced_labels.to('cpu').numpy()
        labels_to_append = []
        predictions_to_append = []
        logits_to_append = []
        # NOTE(review): the inner `preds` list shadows the batch-level
        # `preds` array above — confusing but harmless since the outer
        # value is fully consumed by this loop's zip.
        for prediction, r_label, logit in zip(preds, reduced_labels,
                                              logits_probs):
            preds = []
            labels = []
            logs = []
            for pred, lab, log in zip(prediction, r_label, logit):
                if lab.item(
                ) == -1:  # masked label; -1 means do not collect this label
                    continue
                preds.append(pred)
                labels.append(lab)
                logs.append(log)
            predictions_to_append.append(preds)
            labels_to_append.append(labels)
            logits_to_append.append(logs)
        predictions.extend(predictions_to_append)
        true_labels.extend(labels_to_append)
        data_instances.extend(b_input_ids)
        probs.extend(logits_to_append)
        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
    if mt:
        print('num of edges', total_edges, 'incorrect_edges:',
              incorrect_edges)
        # NOTE(review): 'aacuracy' is a typo in the printed string; left
        # untouched here since it is runtime output.
        print('aacuracy', (total_edges - incorrect_edges) / total_edges)
    eval_loss = eval_loss / nb_eval_steps
    logger.info("eval loss (only main): {}".format(eval_loss))
    # Invert the tag mapping to decode integer predictions back to tag names.
    idx2tags = {tags2idx[t]: t for t in tags2idx}
    pred_tags = [[idx2tags[p_i] for p_i in p] for p in predictions]
    valid_tags = [[idx2tags[l_i] for l_i in l] for l in true_labels]
    logger.info("Seqeval accuracy: {}".format(
        accuracy_score(valid_tags, pred_tags)))
    fscore = f1_score(valid_tags, pred_tags)
    logger.info("Seqeval F1-Score: {}".format(fscore))
    logger.info("Seqeval Classification report: -- ")
    logger.info(classification_report(valid_tags, pred_tags))
    final_labels = [[idx2tags[p_i] for p_i in p] for p in predictions]
    return final_labels, probs, fscore