def evaluate(self, loader, punct=False):
    """Evaluate the transition-based parser on *loader*.

    Runs the parser batch by batch, flattens the per-sentence predicted
    (head, rel) pairs, and scores them against the gold heads/rels under
    ``mask_heads``.

    Args:
        loader: iterable of (words, tags, masks, heads, rels, mask_heads)
            batches.
        punct: unused here; kept for signature compatibility with the
            sibling ``evaluate`` implementations.

    Returns:
        Metric: accumulated attachment metric over the whole loader.
    """
    self.parser.eval()
    metric = Metric()
    pbar = tqdm(total=len(loader))
    for words, tags, masks, heads, rels, mask_heads in loader:
        # one parser State per sentence in the batch
        states = [
            State(mask, tags.device, self.vocab.bert_index,
                  self.config.input_graph) for mask in masks
        ]
        states = self.parser(words, tags, masks, states)
        pred_heads = []
        pred_rels = []
        for state in states:
            # state.head holds (head, rel) pairs; [1:] drops the root slot
            pred_heads.append([h[0] for h in state.head][1:])
            pred_rels.append([h[1] for h in state.head][1:])
        # flatten per-sentence lists into one token-level sequence
        pred_heads = [item for sublist in pred_heads for item in sublist]
        pred_rels = [item for sublist in pred_rels for item in sublist]
        pred_heads = torch.tensor(pred_heads).to(heads.device)
        pred_rels = torch.tensor(pred_rels).to(heads.device)
        heads = heads[mask_heads]
        rels = rels[mask_heads]
        pbar.update(1)
        metric(pred_heads, pred_rels, heads, rels)
        del states  # free per-batch parser states eagerly
    # fix: close the progress bar so it does not leak / corrupt later output
    pbar.close()
    return metric
def predict(self, loader):
    """Predict heads and relation labels with the transition-based parser.

    Same decoding loop as :meth:`evaluate`, but additionally collects the
    per-sentence predictions (split back by sentence length) for output.

    Args:
        loader: iterable of (words, tags, masks, heads, rels, mask_heads)
            batches.

    Returns:
        tuple: (all_arcs, all_rels, metric) where ``all_arcs`` is a list of
        per-sentence head-index lists, ``all_rels`` the corresponding
        relation-label lists, and ``metric`` the accumulated Metric.
    """
    self.parser.eval()
    metric = Metric()
    pbar = tqdm(total=len(loader))
    all_arcs, all_rels = [], []
    for words, tags, masks, heads, rels, mask_heads in loader:
        states = [
            State(mask, tags.device, self.vocab.bert_index,
                  self.config.input_graph) for mask in masks
        ]
        states = self.parser(words, tags, masks, states)
        pred_heads = []
        pred_rels = []
        for state in states:
            # state.head holds (head, rel) pairs; [1:] drops the root slot
            pred_heads.append([h[0] for h in state.head][1:])
            pred_rels.append([h[1] for h in state.head][1:])
        # flatten per-sentence lists into one token-level sequence
        pred_heads = [item for sublist in pred_heads for item in sublist]
        pred_rels = [item for sublist in pred_rels for item in sublist]
        pred_heads = torch.tensor(pred_heads).to(heads.device)
        pred_rels = torch.tensor(pred_rels).to(heads.device)
        heads = heads[mask_heads]
        rels = rels[mask_heads]
        metric(pred_heads, pred_rels, heads, rels)
        # split flat predictions back into per-sentence chunks
        lens = masks.sum(1).tolist()
        all_arcs.extend(torch.split(pred_heads, lens))
        all_rels.extend(torch.split(pred_rels, lens))
        pbar.update(1)
    # fix: close the progress bar so it does not leak / corrupt later output
    pbar.close()
    all_arcs = [seq.tolist() for seq in all_arcs]
    all_rels = [self.vocab.id2rel(seq) for seq in all_rels]
    return all_arcs, all_rels, metric
def evaluate(self, loader, punct=False):
    """Score the biaffine parser on *loader*.

    Averages the arc/rel loss over batches and accumulates an attachment
    metric over all non-pad, non-root (and optionally non-punctuation)
    tokens.

    Args:
        loader: iterable of (words, tags, arcs, rels) batches.
        punct: when False, punctuation tokens are excluded from scoring.

    Returns:
        tuple: (mean loss over batches, Metric).
    """
    self.parser.eval()
    loss, metric = 0, Metric()
    for words, tags, arcs, rels in loader:
        token_mask = words.ne(self.vocab.pad_index)
        # position 0 of every sentence is the dummy root token: never score it
        token_mask[:, 0] = 0
        if not punct:
            # drop punctuation tokens from evaluation unless requested
            punct_ids = words.new_tensor(self.vocab.puncts)
            token_mask &= words.unsqueeze(-1).ne(punct_ids).all(-1)
        s_arc, s_rel = self.parser(words, tags)
        s_arc = s_arc[token_mask]
        s_rel = s_rel[token_mask]
        gold_arcs = arcs[token_mask]
        gold_rels = rels[token_mask]
        pred_arcs, pred_rels = self.decode(s_arc, s_rel)
        loss += self.get_loss(s_arc, s_rel, gold_arcs, gold_rels)
        metric(pred_arcs, pred_rels, gold_arcs, gold_rels)
    loss /= len(loader)
    return loss, metric
def __call__(self, config):
    """End-to-end training pipeline for the biaffine parser.

    Loads the corpora, builds/loads the vocabulary, trains with Adam +
    exponential LR decay, early-stops on dev score, then reports the
    score of the best checkpoint on test.

    Args:
        config: namespace/config object carrying file paths and
            hyper-parameters (ftrain, fdev, ftest, vocab, batch_size,
            buckets, lr, epochs, patience, ...).
    """
    print("Preprocess the data")
    train = Corpus.load(config.ftrain)
    dev = Corpus.load(config.fdev)
    test = Corpus.load(config.ftest)
    # reuse a cached vocab if one exists; otherwise build it from the
    # training corpus and attach pretrained embeddings, then cache it
    if os.path.exists(config.vocab):
        vocab = torch.load(config.vocab)
    else:
        vocab = Vocab.from_corpus(corpus=train, min_freq=2)
        vocab.read_embeddings(Embedding.load(config.fembed, config.unk))
        torch.save(vocab, config.vocab)
    # propagate vocabulary sizes / special indices into the config
    config.update({
        'n_words': vocab.n_train_words,
        'n_tags': vocab.n_tags,
        'n_rels': vocab.n_rels,
        'pad_index': vocab.pad_index,
        'unk_index': vocab.unk_index
    })
    print(vocab)

    print("Load the dataset")
    trainset = TextDataset(vocab.numericalize(train))
    devset = TextDataset(vocab.numericalize(dev))
    testset = TextDataset(vocab.numericalize(test))
    # set the data loaders (bucketed by length; only train is shuffled)
    train_loader = batchify(dataset=trainset,
                            batch_size=config.batch_size,
                            n_buckets=config.buckets,
                            shuffle=True)
    dev_loader = batchify(dataset=devset,
                          batch_size=config.batch_size,
                          n_buckets=config.buckets)
    test_loader = batchify(dataset=testset,
                           batch_size=config.batch_size,
                           n_buckets=config.buckets)
    print(f"{'train:':6} {len(trainset):5} sentences in total, "
          f"{len(train_loader):3} batches provided")
    print(f"{'dev:':6} {len(devset):5} sentences in total, "
          f"{len(dev_loader):3} batches provided")
    print(f"{'test:':6} {len(testset):5} sentences in total, "
          f"{len(test_loader):3} batches provided")

    print("Create the model")
    parser = BiaffineParser(config, vocab.embeddings)
    if torch.cuda.is_available():
        parser = parser.cuda()
    print(f"{parser}\n")

    model = Model(vocab, parser)
    total_time = timedelta()
    best_e, best_metric = 1, Metric()
    model.optimizer = Adam(model.parser.parameters(),
                           config.lr,
                           (config.beta_1, config.beta_2),
                           config.epsilon)
    # decay ** (1/steps) gives a per-step factor that reaches `decay`
    # after `steps` scheduler steps
    model.scheduler = ExponentialLR(model.optimizer,
                                    config.decay ** (1 / config.steps))

    for epoch in range(1, config.epochs + 1):
        start = datetime.now()
        # train one epoch and update the parameters
        model.train(train_loader)

        print(f"Epoch {epoch} / {config.epochs}:")
        loss, train_metric = model.evaluate(train_loader, config.punct)
        print(f"{'train:':6} Loss: {loss:.4f} {train_metric}")
        loss, dev_metric = model.evaluate(dev_loader, config.punct)
        print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
        loss, test_metric = model.evaluate(test_loader, config.punct)
        print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")
        t = datetime.now() - start
        # save the model if it is the best so far
        # (only after `patience` warm-up epochs have passed)
        if dev_metric > best_metric and epoch > config.patience:
            best_e, best_metric = epoch, dev_metric
            model.parser.save(config.model + f".{best_e}")
            print(f"{t}s elapsed (saved)\n")
        else:
            print(f"{t}s elapsed\n")
        total_time += t
        # early stopping: no dev improvement for `patience` epochs
        if epoch - best_e >= config.patience:
            break
    # reload the best checkpoint and report its test score
    model.parser = BiaffineParser.load(config.model + f".{best_e}")
    loss, metric = model.evaluate(test_loader, config.punct)

    print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
    print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
    print(f"average time of each epoch is {total_time / epoch}s")
    print(f"{total_time}s elapsed")
def __call__(self, config):
    """End-to-end training pipeline for the transition-based (BERT) parser.

    Loads corpora and oracle transition sequences, builds or restores the
    vocabulary, trains with AdamW + linear warmup (optionally with separate
    optimizers for BERT and non-BERT parameters), checkpoints every epoch,
    early-stops on dev score, and finally evaluates the best weights on test.

    Args:
        config: namespace/config object with paths and hyper-parameters
            (ftrain, fdev, ftest, ftrain_seq, model, modelname, checkpoint,
            use_two_opts, lr, lr2, epochs, patience, ...).
    """
    print("Preprocess the data")
    train = Corpus.load(config.ftrain)
    dev = Corpus.load(config.fdev)
    test = Corpus.load(config.ftest)
    # make sure the output directories exist
    # (fix: `not path.exists(...)` instead of the `!= True` anti-idiom)
    if not path.exists(config.model):
        os.mkdir(config.model)
    if not path.exists("model/"):
        os.mkdir("model/")
    if not path.exists(config.model + config.modelname):
        os.mkdir(config.model + config.modelname)
    # resume from a saved vocab when restarting from a checkpoint,
    # otherwise build it from all three corpora
    if config.checkpoint:
        vocab = torch.load(config.main_path + config.vocab +
                           config.modelname + "/vocab.tag")
    else:
        vocab = Vocab.from_corpus(config=config,
                                  corpus=train,
                                  corpus_dev=dev,
                                  corpus_test=test,
                                  min_freq=0)
    # oracle transition sequences for the training corpus
    train_seq = read_seq(config.ftrain_seq, vocab)
    total_act = sum(len(seq) for seq in train_seq)
    print(f"number of transitions:{total_act}")
    torch.save(vocab, config.vocab + config.modelname + "/vocab.tag")
    # propagate vocabulary sizes / special indices into the config
    config.update({
        'n_words': vocab.n_train_words,
        'n_tags': vocab.n_tags,
        'n_rels': vocab.n_rels,
        'n_trans': vocab.n_trans,
        'pad_index': vocab.pad_index,
        'unk_index': vocab.unk_index
    })

    print("Load the dataset")
    trainset = TextDataset(vocab.numericalize(train, train_seq))
    devset = TextDataset(vocab.numericalize(dev))
    testset = TextDataset(vocab.numericalize(test))
    # set the data loaders (bucketed by length; only train is shuffled)
    train_loader, _ = batchify(dataset=trainset,
                               batch_size=config.batch_size,
                               n_buckets=config.buckets,
                               shuffle=True)
    dev_loader, _ = batchify(dataset=devset,
                             batch_size=config.batch_size,
                             n_buckets=config.buckets)
    test_loader, _ = batchify(dataset=testset,
                              batch_size=config.batch_size,
                              n_buckets=config.buckets)
    print(f"{'train:':6} {len(trainset):5} sentences in total, "
          f"{len(train_loader):3} batches provided")
    print(f"{'dev:':6} {len(devset):5} sentences in total, "
          f"{len(dev_loader):3} batches provided")
    print(f"{'test:':6} {len(testset):5} sentences in total, "
          f"{len(test_loader):3} batches provided")

    print("Create the model")
    if config.checkpoint:
        parser = Parser.load(config.main_path + config.model +
                             config.modelname + "/parser-checkpoint")
    else:
        parser = Parser(config, vocab.bertmodel)
    n_params = sum(p.numel() for p in parser.parameters() if p.requires_grad)
    print(f"number of parameters:{n_params}")
    if torch.cuda.is_available():
        print('Train/Evaluate on GPU')
        device = torch.device('cuda')
        parser = parser.to(device)

    model = Model(vocab, parser, config, vocab.n_rels)
    total_time = timedelta()
    best_e, best_metric = 1, Metric()

    ## prepare optimisers
    num_train_optimization_steps = int(config.epochs * len(train_loader))
    warmup_steps = int(config.warmupproportion * num_train_optimization_steps)
    ## one for parsing parameters, one for BERT parameters
    if config.use_two_opts:
        model_nonbert = []
        model_bert = []
        # these layers live inside the BERT module but are trained with
        # the non-BERT learning rate
        layernorm_params = [
            'layernorm_key_layer', 'layernorm_value_layer',
            'dp_relation_k', 'dp_relation_v'
        ]
        for name, param in parser.named_parameters():
            if 'bert' in name and not any(nd in name
                                          for nd in layernorm_params):
                model_bert.append((name, param))
            else:
                model_nonbert.append((name, param))
        # Prepare optimizer and schedule (linear warmup and decay) for
        # non-BERT parameters; biases and LayerNorm weights get no decay
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters_nonbert = [{
            'params': [p for n, p in model_nonbert
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': config.weight_decay
        }, {
            'params': [p for n, p in model_nonbert
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        model.optimizer_nonbert = AdamW(optimizer_grouped_parameters_nonbert,
                                        lr=config.lr2)
        model.scheduler_nonbert = get_linear_schedule_with_warmup(
            model.optimizer_nonbert,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)
        # Prepare optimizer and schedule (linear warmup and decay) for
        # BERT parameters
        optimizer_grouped_parameters_bert = [{
            'params': [p for n, p in model_bert
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': config.weight_decay
        }, {
            'params': [p for n, p in model_bert
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        model.optimizer_bert = AdamW(optimizer_grouped_parameters_bert,
                                     lr=config.lr)
        model.scheduler_bert = get_linear_schedule_with_warmup(
            model.optimizer_bert,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)
    else:
        # Prepare a single optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [p for n, p in parser.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': config.weight_decay
        }, {
            'params': [p for n, p in parser.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
        model.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr)
        model.scheduler = get_linear_schedule_with_warmup(
            model.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

    start_epoch = 1
    ## load model, optimiser, and other parameters from a checkpoint
    if config.checkpoint:
        check_load = torch.load(config.main_path + config.model +
                                config.modelname + "/checkpoint")
        if config.use_two_opts:
            model.optimizer_bert.load_state_dict(
                check_load['optimizer_bert'])
            model.optimizer_nonbert.load_state_dict(
                check_load['optimizer_nonbert'])
            model.scheduler_bert.load_state_dict(
                check_load['lr_schedule_bert'])
            model.scheduler_nonbert.load_state_dict(
                check_load['lr_schedule_nonbert'])
            start_epoch = check_load['epoch'] + 1
            best_e = check_load['best_e']
            best_metric = check_load['best_metric']
        else:
            model.optimizer.load_state_dict(check_load['optimizer'])
            model.scheduler.load_state_dict(check_load['lr_schedule'])
            start_epoch = check_load['epoch'] + 1
            best_e = check_load['best_e']
            best_metric = check_load['best_metric']

    # fix: use context managers for the log file instead of open()/close()
    log_path = config.model + config.modelname + "/baseline.txt"
    with open(log_path, "a") as log_file:
        log_file.write("New Model:\n")

    for epoch in range(start_epoch, config.epochs + 1):
        start = datetime.now()
        # train one epoch and update the parameters
        model.train(train_loader)
        print(f"Epoch {epoch} / {config.epochs}:")
        dev_metric = model.evaluate(dev_loader, config.punct)
        with open(log_path, "a") as log_file:
            log_file.write(str(epoch) + "\n")
            print(f"{'dev:':6} {dev_metric}")
            log_file.write(f"{'dev:':6} {dev_metric}")
            log_file.write("\n")
        t = datetime.now() - start
        # save the model if it is the best so far
        if dev_metric > best_metric:
            best_e, best_metric = epoch, dev_metric
            print(config.model + config.modelname + "/model_weights")
            model.parser.save(config.model + config.modelname +
                              "/model_weights")
            print(f"{t}s elapsed (saved)\n")
        else:
            print(f"{t}s elapsed\n")
        total_time += t
        # early stopping: no dev improvement for `patience` epochs
        if epoch - best_e >= config.patience:
            break
        ## save checkpoint (optimizer state + epoch bookkeeping)
        if config.use_two_opts:
            checkpoint = {
                "epoch": epoch,
                "optimizer_bert": model.optimizer_bert.state_dict(),
                "lr_schedule_bert": model.scheduler_bert.state_dict(),
                "lr_schedule_nonbert": model.scheduler_nonbert.state_dict(),
                "optimizer_nonbert": model.optimizer_nonbert.state_dict(),
                'best_metric': best_metric,
                'best_e': best_e
            }
            torch.save(checkpoint,
                       config.main_path + config.model + config.modelname +
                       "/checkpoint")
            parser.save(config.main_path + config.model + config.modelname +
                        "/parser-checkpoint")
        else:
            checkpoint = {
                "epoch": epoch,
                "optimizer": model.optimizer.state_dict(),
                "lr_schedule": model.scheduler.state_dict(),
                'best_metric': best_metric,
                'best_e': best_e
            }
            torch.save(checkpoint,
                       config.main_path + config.model + config.modelname +
                       "/checkpoint")
            parser.save(config.main_path + config.model + config.modelname +
                        "/parser-checkpoint")

    # reload the best weights and report the final test score
    model.parser = Parser.load(config.model + config.modelname +
                               "/model_weights")
    metric = model.evaluate(test_loader, config.punct)
    print(metric)
    print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
    print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
    print(f"average time of each epoch is {total_time / epoch}s")
    print(f"{total_time}s elapsed")
def __call__(self, config):
    """End-to-end training pipeline for the iterative-refinement biaffine parser.

    Loads gold (and optionally predicted) corpora in CoNLL-U or native
    format, builds the vocabulary, trains with AdamW + linear warmup
    (optionally split BERT / non-BERT optimizers), early-stops on dev
    score, and evaluates the best checkpoint on test.

    Args:
        config: namespace/config object with paths and hyper-parameters
            (ftrain/fdev/ftest, fpredicted_*, input_type, use_predicted,
            use_two_opts, lr1, lr2, num_iter_encoder, epochs, patience, ...).
    """
    print("Preprocess the data")
    if config.input_type == "conllu":
        train = UniversalDependenciesDatasetReader()
        train.load(config.ftrain)
        dev = UniversalDependenciesDatasetReader()
        dev.load(config.fdev)
        test = UniversalDependenciesDatasetReader()
        test.load(config.ftest)
    else:
        train = Corpus.load(config.ftrain)
        dev = Corpus.load(config.fdev)
        test = Corpus.load(config.ftest)
    # optionally load parser-predicted trees used as refinement input
    if config.use_predicted:
        if config.input_type == "conllu":
            train_predicted = UniversalDependenciesDatasetReader()
            train_predicted.load(config.fpredicted_train)
            dev_predicted = UniversalDependenciesDatasetReader()
            dev_predicted.load(config.fpredicted_dev)
            test_predicted = UniversalDependenciesDatasetReader()
            test_predicted.load(config.fpredicted_test)
        else:
            train_predicted = Corpus.load(config.fpredicted_train)
            dev_predicted = Corpus.load(config.fpredicted_dev)
            test_predicted = Corpus.load(config.fpredicted_test)
    # make sure the output directories exist
    # (fix: `not path.exists(...)` instead of the `!= True` anti-idiom)
    if not path.exists(config.main_path + "/exp"):
        os.mkdir(config.main_path + "/exp")
    if not path.exists(config.main_path + "/model"):
        os.mkdir(config.main_path + "/model")
    if not path.exists(config.main_path + config.model + config.modelname):
        os.mkdir(config.main_path + config.model + config.modelname)

    vocab = Vocab.from_corpus(config=config, corpus=train, min_freq=2)
    torch.save(vocab, config.main_path + config.vocab + config.modelname +
               "/vocab.tag")
    # propagate vocabulary sizes / special indices into the config
    config.update({
        'n_words': vocab.n_train_words,
        'n_tags': vocab.n_tags,
        'n_rels': vocab.n_rels,
        'pad_index': vocab.pad_index,
        'unk_index': vocab.unk_index
    })

    print("Load the dataset")
    if config.use_predicted:
        trainset = TextDataset(vocab.numericalize(train, train_predicted))
        devset = TextDataset(vocab.numericalize(dev, dev_predicted))
        testset = TextDataset(vocab.numericalize(test, test_predicted))
    else:
        trainset = TextDataset(vocab.numericalize(train))
        devset = TextDataset(vocab.numericalize(dev))
        testset = TextDataset(vocab.numericalize(test))
    # set the data loaders (bucketed by length; only train is shuffled)
    train_loader, _ = batchify(dataset=trainset,
                               batch_size=config.batch_size,
                               n_buckets=config.buckets,
                               shuffle=True)
    dev_loader, _ = batchify(dataset=devset,
                             batch_size=config.batch_size,
                             n_buckets=config.buckets)
    test_loader, _ = batchify(dataset=testset,
                              batch_size=config.batch_size,
                              n_buckets=config.buckets)
    print(f"{'train:':6} {len(trainset):5} sentences in total, "
          f"{len(train_loader):3} batches provided")
    print(f"{'dev:':6} {len(devset):5} sentences in total, "
          f"{len(dev_loader):3} batches provided")
    print(f"{'test:':6} {len(testset):5} sentences in total, "
          f"{len(test_loader):3} batches provided")

    print("Create the model")
    parser = BiaffineParser(config, vocab.n_rels, vocab.bertmodel)
    n_params = sum(p.numel() for p in parser.parameters() if p.requires_grad)
    print(f"number of pars:{n_params}")
    if torch.cuda.is_available():
        print('device:cuda')
        device = torch.device('cuda')
        parser = parser.to(device)

    model = Model(vocab, parser, config, vocab.n_rels)
    total_time = timedelta()
    best_e, best_metric = 1, Metric()

    # each refinement iteration over the encoder counts as a training step
    num_train_optimization_steps = int(config.num_iter_encoder *
                                       config.epochs * len(train_loader))
    warmup_steps = int(config.warmupproportion * num_train_optimization_steps)
    if config.use_two_opts:
        model_nonbert = []
        model_bert = []
        # these layers live inside the BERT module but are trained with
        # the non-BERT learning rate
        layernorm_params = ['layernorm_key_layer', 'layernorm_value_layer',
                            'dp_relation_k', 'dp_relation_v']
        for name, param in parser.named_parameters():
            if 'bert' in name and not any(nd in name
                                          for nd in layernorm_params):
                model_bert.append((name, param))
            else:
                model_nonbert.append((name, param))
        # Prepare optimizer and schedule (linear warmup and decay) for
        # non-BERT parameters; biases and LayerNorm weights get no decay
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters_nonbert = [
            {'params': [p for n, p in model_nonbert
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': config.weight_decay},
            {'params': [p for n, p in model_nonbert
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        model.optimizer_nonbert = AdamW(optimizer_grouped_parameters_nonbert,
                                        lr=config.lr2)
        model.scheduler_nonbert = get_linear_schedule_with_warmup(
            model.optimizer_nonbert,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)
        # Prepare optimizer and schedule (linear warmup and decay) for
        # BERT parameters
        optimizer_grouped_parameters_bert = [
            {'params': [p for n, p in model_bert
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': config.weight_decay},
            {'params': [p for n, p in model_bert
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        model.optimizer_bert = AdamW(optimizer_grouped_parameters_bert,
                                     lr=config.lr1)
        model.scheduler_bert = get_linear_schedule_with_warmup(
            model.optimizer_bert,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)
    else:
        # Prepare a single optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in parser.named_parameters()
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': config.weight_decay},
            {'params': [p for n, p in parser.named_parameters()
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        model.optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr1)
        model.scheduler = get_linear_schedule_with_warmup(
            model.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

    for epoch in range(1, config.epochs + 1):
        start = datetime.now()
        # train one epoch and update the parameters
        if config.use_predicted:
            model.train_predicted(train_loader)
        else:
            model.train(train_loader)
        print(f"Epoch {epoch} / {config.epochs}:")
        if config.use_predicted:
            loss, dev_metric = model.evaluate_predicted(dev_loader,
                                                        config.punct)
        else:
            loss, dev_metric = model.evaluate(dev_loader, config.punct)
        print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
        if config.use_predicted:
            loss, test_metric = model.evaluate_predicted(test_loader,
                                                         config.punct)
        else:
            loss, test_metric = model.evaluate(test_loader, config.punct)
        print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")
        t = datetime.now() - start
        # save the model if it is the best so far
        if dev_metric > best_metric:
            best_e, best_metric = epoch, dev_metric
            print(config.model + config.modelname + "/model_weights")
            model.parser.save(config.main_path + config.model +
                              config.modelname + "/model_weights")
            print(f"{t}s elapsed (saved)\n")
        else:
            print(f"{t}s elapsed\n")
        total_time += t
        # early stopping: no dev improvement for `patience` epochs
        if epoch - best_e >= config.patience:
            break

    # reload the best weights and report the final test score
    model.parser = BiaffineParser.load(config.main_path + config.model +
                                       config.modelname + "/model_weights")
    if config.use_predicted:
        loss, metric = model.evaluate_predicted(test_loader, config.punct)
    else:
        loss, metric = model.evaluate(test_loader, config.punct)
    print(metric)
    print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
    print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
    print(f"average time of each epoch is {total_time / epoch}s")
    print(f"{total_time}s elapsed")
def predict(self, loader):
    """Predict dependency trees with iterative encoder refinement.

    For each batch, runs the parser repeatedly: the first pass scores from
    words/tags alone, later passes feed the previous pass's decoded graph
    back in. Sentences whose predictions stop changing are frozen via
    ``stop_sign``; refinement ends when all sentences converge or the
    iteration budget is exhausted.

    Args:
        loader: iterable of (words, tags, arcs, rels, mask, sbert_arc,
            sbert_rel, offsets) batches; ``offsets`` is unused here.

    Returns:
        tuple: (all_arcs, all_rels, metric).
    """
    self.parser.eval()
    metric = Metric()
    all_arcs, all_rels = [], []
    for words, tags, arcs, rels, mask, sbert_arc, sbert_rel, offsets in loader:
        # per-sentence "still refining" flag: 1 = keep refining
        stop_sign = torch.ones(len(words)).long().to(words.device)
        mask_gold = arcs > 0
        # keep a pristine copy of the token mask; `mask` itself is
        # progressively zeroed for converged sentences below
        mask_unused = mask.clone()
        ## iterate over encoder
        for counter in range(0, self.config.num_iter_encoder):
            self.counter_ref = counter
            if counter == 0:
                # first pass: plain scores, no graph input
                s_arc, s_rel = self.parser(words, tags)
                s_arc_final = s_arc
                s_rel_final = s_rel
            else:
                # later passes: condition the parser on the graph decoded
                # from the previous pass (MST decoding on pass 1, or always
                # when use_mst_train is set; argmax otherwise)
                if self.config.use_mst_train or counter == 1:
                    graph_arc, graph_rel = \
                        self.prepare_mst(s_arc, s_rel, mask,
                                         sbert_arc.clone(),
                                         sbert_rel.clone(), stop_sign, False)
                else:
                    graph_arc, graph_rel = \
                        self.prepare_argmax(s_arc, s_rel, mask,
                                            sbert_arc.clone(),
                                            sbert_rel.clone(), stop_sign,
                                            False)
                s_arc, s_rel = self.parser(words, tags, stop_sign,
                                           graph_arc, graph_rel)
            # decode this pass's predictions for the convergence check
            if self.config.use_mst_train or counter == 1:
                new_arcs, new_rels = self.prepare_mst(
                    s_arc, s_rel, mask, sbert_arc.clone(),
                    sbert_rel.clone(), stop_sign, True)
            else:
                new_arcs, new_rels = self.prepare_argmax(
                    s_arc, s_rel, mask, sbert_arc.clone(),
                    sbert_rel.clone(), stop_sign, True)
            if counter > 0:
                # a sentence stops refining once its prediction repeats
                stop_sign = self.check_stop(stop_sign, new_arcs, prev_arcs,
                                            new_rels, prev_rels, mask)
                if stop_sign.sum() == 0:
                    # print('All Dependency Graphs are converged in this batch')
                    break
            # zero out the mask rows of converged sentences
            # (no-op at counter 0 since stop_sign is all ones there)
            mask = (stop_sign.unsqueeze(1) * mask.long()).bool()
            prev_arcs = new_arcs
            prev_rels = new_rels
            # overwrite final scores only for still-refining sentences
            index = stop_sign.nonzero()
            s_arc_final[index] = s_arc[index]
            s_rel_final[index] = s_rel[index]
        gold_arcs, gold_rels = arcs[mask_gold], rels[mask_gold]
        # NOTE(review): there is no `else` branch — with use_mst_eval False
        # nothing is scored or collected for the batch, so the statements
        # below must live inside this `if` (otherwise pred_* would be
        # unbound). Confirm this matches the intended config usage.
        if self.config.use_mst_eval:
            pred_rels, pred_arcs_org, pred_arcs = self.decode_mst(
                s_arc_final, s_rel_final, mask_unused, prepare=False,
                do_predict=True)
            metric(pred_arcs, pred_rels, gold_arcs, gold_rels)
            # split flat predictions back into per-sentence chunks
            lens = mask_unused.sum(1).tolist()
            all_arcs.extend(torch.split(pred_arcs_org, lens))
            all_rels.extend(torch.split(pred_rels, lens))
    all_arcs = [seq.tolist() for seq in all_arcs]
    all_rels = [self.vocab.id2rel(seq) for seq in all_rels]
    return all_arcs, all_rels, metric
def evaluate(self, loader, punct=False):
    """Evaluate the iterative-refinement parser on *loader*.

    Mirrors the refinement loop used at training/prediction time: the
    first pass scores from words/tags alone, later passes condition on the
    previously decoded graph, and per-sentence refinement halts once a
    sentence's prediction stops changing (``stop_sign``).

    Args:
        loader: iterable of (words, tags, arcs, rels, mask, sbert_arc,
            sbert_rel) batches.
        punct: unused here; kept for signature compatibility with the
            sibling ``evaluate`` implementations.

    Returns:
        tuple: (mean loss over batches, Metric).
    """
    self.parser.eval()
    loss, metric = 0, Metric()
    pbar = tqdm(total=len(loader))
    for words, tags, arcs, rels, mask, sbert_arc, sbert_rel in loader:
        # per-sentence "still refining" flag: 1 = keep refining
        stop_sign = torch.ones(len(words)).long().to(words.device)
        mask_gold = arcs > 0
        # keep a pristine copy of the token mask; `mask` itself is
        # progressively zeroed for converged sentences below
        mask_unused = mask.clone()
        ## iterate over encoder
        for counter in range(self.config.num_iter_encoder):
            self.counter_ref = counter
            if counter == 0:
                # first pass: plain scores, no graph input
                s_arc, s_rel = self.parser(words, tags)
                s_arc_final = s_arc
                s_rel_final = s_rel
            else:
                # later passes: condition on the previously decoded graph
                if self.config.use_mst_train or counter == 1:
                    graph_arc, graph_rel = \
                        self.prepare_mst(s_arc, s_rel, mask,
                                         sbert_arc.clone(),
                                         sbert_rel.clone(), stop_sign, False)
                else:
                    graph_arc, graph_rel = \
                        self.prepare_argmax(s_arc, s_rel, mask,
                                            sbert_arc.clone(),
                                            sbert_rel.clone(), stop_sign,
                                            False)
                s_arc, s_rel = self.parser(words, tags, stop_sign,
                                           graph_arc, graph_rel)
            # decode this pass's predictions for the convergence check
            if self.config.use_mst_train or counter == 1:
                new_arcs, new_rels = self.prepare_mst(
                    s_arc, s_rel, mask, sbert_arc.clone(),
                    sbert_rel.clone(), stop_sign, True)
            else:
                new_arcs, new_rels = self.prepare_argmax(
                    s_arc, s_rel, mask, sbert_arc.clone(),
                    sbert_rel.clone(), stop_sign, True)
            if counter > 0:
                # a sentence stops refining once its prediction repeats
                stop_sign = self.check_stop(stop_sign, new_arcs, prev_arcs,
                                            new_rels, prev_rels, mask)
                if stop_sign.sum() == 0:
                    # print('All Dependency Graphs are converged in this batch')
                    break
            # zero out the mask rows of converged sentences
            # (no-op at counter 0 since stop_sign is all ones there)
            mask = (stop_sign.unsqueeze(1) * mask.long()).bool()
            prev_arcs = new_arcs
            prev_rels = new_rels
            # overwrite final scores only for still-refining sentences
            index = stop_sign.nonzero()
            s_arc_final[index] = s_arc[index]
            s_rel_final[index] = s_rel[index]
            # optional bookkeeping of how much refinement changed
            if self.config.show_refinement:
                if counter == 0:
                    self.initial_refinement_total(new_arcs.clone(),
                                                  new_rels.clone(),
                                                  mask_unused.clone(),
                                                  arcs.clone(), rels.clone(),
                                                  mask_gold.clone())
                elif counter > 0:
                    self.refinement_total(new_arcs.clone(), new_rels.clone(),
                                          mask_unused.clone(),
                                          stop_sign.clone())
        gold_arcs, gold_rels = arcs[mask_gold], rels[mask_gold]
        if self.config.use_mst_eval:
            pred_arcs, pred_rels = self.decode_mst(s_arc_final, s_rel_final,
                                                   mask_unused,
                                                   prepare=False)
        else:
            pred_arcs, pred_rels = self.decode(s_arc_final, s_rel_final,
                                               mask_unused)
        s_arc_mask = s_arc_final[mask_unused]
        s_rel_mask = s_rel_final[mask_unused]
        loss += self.get_loss(s_arc_mask, s_rel_mask, gold_arcs, gold_rels)
        metric(pred_arcs, pred_rels, gold_arcs, gold_rels)
        pbar.update(1)
    # fix: close the progress bar so it does not leak / corrupt later output
    pbar.close()
    loss /= len(loader)
    return loss, metric