def forward(self, sentences, neg_samples, diora, info):
    batch_size, length = sentences.shape
    size = diora.outside_h.shape[-1]

    # Get the score for the ground truth tree.
    # gold_spans = self.makeLeftTree(info['spans'])
    gold_spans_r = self.makeRightTree(info['spans'])
    gold_scores = self.get_score_for_spans(sentences, diora.saved_scalars, gold_spans_r)

    # Get the score for the maximal tree.
    parse_predictor = CKY(net=diora, word2idx=self.word2idx)
    max_trees = parse_predictor.parse_batch({'sentences': sentences})
    max_spans = [tree_to_spans(x) for x in max_trees]
    max_scores = self.get_score_for_spans(sentences, diora.saved_scalars, max_spans)

    # Margin-based loss: the gold tree should score higher than the predicted tree by at least `margin`.
    loss = max_scores - gold_scores + self.margin
    loss = loss.sum().view(1) / batch_size

    ret = dict(semi_supervised_parsing_loss=loss)

    return loss, ret
def run_parse(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']

    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse

    diora = trainer.net.diora

    ## Turn off outside pass.
    trainer.net.diora.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Topk predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    logger.info('Beginning to parse.')

    with torch.no_grad():
        for i, batch_map in enumerate(batches):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Rather than skipping, just log the trees (they are trivially easy to find).
            if length <= 2:
                for j in range(batch_size):
                    example_id = batch_map['example_ids'][j]
                    tokens = sentences[j].tolist()
                    words = [idx2word[idx] for idx in tokens]
                    if length == 2:
                        o = dict(example_id=example_id, tree=(words[0], words[1]))
                    elif length == 1:
                        o = dict(example_id=example_id, tree=words[0])
                    print(json.dumps(o))
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees = parse_predictor.parse_batch(batch_map)

            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)
                print(json.dumps(o))
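# For illustration only: each printed line is a standalone JSON object. The example
# id and tree below are made up; note that tuples become JSON arrays under json.dumps.
#
#   {"example_id": "ex-1", "tree": [["the", "cat"], ["sat", ["on", ["the", "mat"]]]]}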
def forward(self, sentences, neg_samples, diora, info):
    batch_size, length = sentences.shape
    size = diora.outside_h.shape[-1]

    # TODO: for semi-supervised training, create a "pseudo-tree" using external annotation (e.g. entity boundaries).

    # Get the score for the ground truth tree.
    gold_scores = self.get_score_for_spans(sentences, diora.saved_scalars, info['spans'])

    # Get the score for the maximal tree.
    parse_predictor = CKY(net=diora, word2idx=self.word2idx)
    max_trees = parse_predictor.parse_batch({'sentences': sentences})
    max_spans = [tree_to_spans(x) for x in max_trees]
    max_scores = self.get_score_for_spans(sentences, diora.saved_scalars, max_spans)

    loss = max_scores - gold_scores + self.margin
    loss = loss.sum().view(1) / batch_size

    ret = dict(semi_supervised_parsing_loss=loss)

    return loss, ret
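# Both margin losses above, as well as the training-time evaluation further below,
# rely on a helper `tree_to_spans` that flattens a binarized tree (nested tuples
# with one token per leaf) into its constituent spans. A minimal sketch of one
# possible implementation follows; the (position, size) span convention is an
# assumption and may not match the repo's actual helper.
def tree_to_spans_sketch(tree):
    spans = []

    def helper(node, pos):
        # A leaf covers a single position and contributes no span of its own.
        if not isinstance(node, (list, tuple)):
            return 1
        size = 0
        for child in node:
            size += helper(child, pos + size)
        spans.append((pos, size))
        return size

    helper(tree, 0)
    return spans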
class TreeHelper(object):
    def __init__(self, diora, word2idx):
        self.diora = diora
        self.word2idx = word2idx
        self.idx2word = {idx: w for w, idx in self.word2idx.items()}

    def init(self, options):
        if options.parse_mode == 'latent':
            self.parse_predictor = CKY(net=self.diora, word2idx=self.word2idx)

        ## Monkey patch parsing specific methods.
        override_init_with_batch(self.diora)
        override_inside_hook(self.diora)

    def get_trees_for_batch(self, batch_map, options):
        sentences = batch_map['sentences']
        batch_size = sentences.shape[0]
        length = sentences.shape[1]

        # trees
        if options.parse_mode == 'all-spans':
            raise Exception('Does not support this mode.')
        elif options.parse_mode == 'latent':
            trees = self.parse_predictor.parse_batch(batch_map)
        elif options.parse_mode == 'given':
            trees = batch_map['trees']

        # spans
        spans = []
        for ii, tr in enumerate(trees):
            s = [self.idx2word[idx] for idx in sentences[ii].tolist()]
            tr = replace_leaves(tr, s)
            if options.postprocess:
                tr = postprocess(tr, s)
            spans.append(tree_to_spans(tr))

        return trees, spans
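# Hypothetical wiring of TreeHelper inside an evaluation loop (the surrounding
# objects and option names are assumptions, shown only to illustrate the call order):
#
#   helper = TreeHelper(trainer.net.diora, word2idx)
#   helper.init(options)  # sets up the CKY predictor when parse_mode == 'latent'
#   trees, spans = helper.get_trees_for_batch(batch_map, options)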
def forward_(self, sentences, neg_samples, diora, info):
    batch_size, length = sentences.shape
    size = diora.outside_h.shape[-1]

    # Get the score for the ground truth tree.
    gold_spans = self.makeLeftTree(info['spans'])
    gold_spans_r = self.makeRightTree(info['spans'])
    gold_scores = self.get_score_for_spans_modified(
        sentences, diora.saved_scalars, gold_spans, info['spans'])
    gold_scores_r = self.get_score_for_spans_modified(
        sentences, diora.saved_scalars, gold_spans_r, info['spans'])

    # Get the score for the maximal tree.
    parse_predictor = CKY(net=diora, word2idx=self.word2idx)
    max_trees = parse_predictor.parse_batch({'sentences': sentences})
    max_spans = [tree_to_spans(x) for x in max_trees]
    roots, diora_spans = self.findClosestParent(max_spans, info['spans'])

    # Note: only this recomputation with the right-branching gold spans feeds the loss below.
    gold_scores = self.get_score_for_spans_modified(
        sentences, diora.saved_scalars, gold_spans_r, info['spans'])
    max_scores = self.get_score_for_spans_modified(
        sentences, diora.saved_scalars, diora_spans, roots)

    total_loss = 0
    for dp in range(len(gold_scores)):
        gold_score_data = gold_scores[dp]
        max_score_data = max_scores[dp]
        loss = 0
        for i in range(len(gold_score_data)):
            if (int(info['spans'][dp][i][0]) == int(roots[dp][i][0])
                    and int(info['spans'][dp][i][1]) == int(roots[dp][i][1])):
                # The closest predicted subtree already matches the gold span; no penalty.
                continue
            else:
                # Margin-based hinge term for spans the parser got wrong.
                loss += max_score_data[i] - gold_score_data[i] + self.margin
        total_loss += loss

    # Average over the batch. If the accumulated loss is a tensor, divide it directly
    # so it stays connected to the computation graph; wrapping it in torch.tensor
    # would detach it.
    if torch.is_tensor(total_loss):
        loss = (total_loss / batch_size).view(1)
    else:
        # No violated spans in this batch; total_loss is still a plain number.
        loss = torch.tensor(total_loss / batch_size, requires_grad=True).view(1)

    ret = dict(semi_supervised_parsing_loss=loss)

    return loss, ret
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']

    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse

    diora = trainer.net.encoder

    ## Monkey patch parsing specific methods.
    override_init_with_batch(diora)
    override_inside_hook(diora)

    ## Turn off outside pass.
    # trainer.net.encoder.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Parse predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    output_path1 = os.path.abspath(os.path.join(options.experiment_path, 'parse_mnli1.jsonl'))
    # NOTE: output_path2 is logged but never written; both parses go into output_path1.
    output_path2 = os.path.abspath(os.path.join(options.experiment_path, 'parse_mnli2.jsonl'))

    logger.info('Beginning.')
    logger.info('Writing output to = {}'.format(output_path1))
    logger.info('Writing output to = {}'.format(output_path2))

    f = open(output_path1, 'w')

    with torch.no_grad():
        for i, batch_map in tqdm(enumerate(batches)):
            sentences1 = batch_map['sentences_1']
            sentences2 = batch_map['sentences_2']
            batch_size = sentences1.shape[0]
            length = sentences1.shape[1]

            # Skip very short sentences.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees1 = parse_predictor.parse_batch({'sentences': sentences1})
            trees2 = parse_predictor.parse_batch({'sentences': sentences2})

            for ii, (tr1, tr2) in enumerate(zip(trees1, trees2)):
                example_id = batch_map['example_ids'][ii]
                s1 = [idx2word[idx] for idx in sentences1[ii].tolist()]
                s2 = [idx2word[idx] for idx in sentences2[ii].tolist()]
                tr1 = replace_leaves(tr1, s1)
                tr2 = replace_leaves(tr2, s2)
                if options.postprocess:
                    tr1 = postprocess(tr1, s1)
                    tr2 = postprocess(tr2, s2)
                o = collections.OrderedDict(example_id=example_id, sentence1=tr1, sentence2=tr2)
                f.write(json.dumps(o) + '\n')

    f.close()
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']

    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse

    diora = trainer.net.diora

    ## Monkey patch parsing specific methods.
    override_init_with_batch(diora)
    override_inside_hook(diora)

    ## Turn off outside pass.
    trainer.net.diora.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Parse predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    output_path = os.path.abspath(os.path.join(options.experiment_path, 'parse.jsonl'))

    logger.info('Beginning.')
    logger.info('Writing output to = {}'.format(output_path))

    f = open(output_path, 'w')

    with torch.no_grad():
        for i, batch_map in tqdm(enumerate(batches)):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Skip very short sentences.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees = parse_predictor.parse_batch(batch_map)

            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                if options.postprocess:
                    tr = postprocess(tr, s)
                o = collections.OrderedDict(example_id=example_id, tree=tr)
                f.write(json.dumps(o) + '\n')

    f.close()
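# The parsing and training code here maps integer leaf ids back to word strings with
# `replace_leaves`. A minimal sketch, assuming the tree is built from nested
# tuples/lists with leaves in left-to-right sentence order (the repo's own helper
# may differ in details):
def replace_leaves_sketch(tree, words):
    def helper(node, pos):
        # Replace each leaf with the word at the next position, left to right.
        if not isinstance(node, (list, tuple)):
            return 1, words[pos]
        size, children = 0, []
        for child in node:
            n, new_child = helper(child, pos + size)
            size += n
            children.append(new_child)
        return size, tuple(children)

    _, new_tree = helper(tree, 0)
    return new_tree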
def run_train(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    seeds = generate_seeds(options.max_epoch, options.seed)

    step = 0

    # Added now
    idx2word = {v: k for k, v in train_iterator.word2idx.items()}
    parse_predictor = CKY(net=trainer.net.diora, word2idx=train_iterator.word2idx)
    # Added now

    for epoch, seed in zip(range(options.max_epoch), seeds):
        # --- Train --- #

        # Added now
        precision = 0
        recall = 0
        total_len = 0
        count_des = 0
        # Added now

        seed = seeds[epoch]

        logger.info('epoch={} seed={}'.format(epoch, seed))

        def myiterator():
            it = train_iterator.get_iterator(random_seed=seed)
            count = 0

            for batch_map in it:
                # TODO: Skip short examples (optionally).
                if batch_map['length'] <= 2:
                    continue
                yield count, batch_map
                count += 1

        for batch_idx, batch_map in myiterator():
            if options.finetune and step >= options.finetune_after:
                trainer.freeze_diora()

            result = trainer.step(batch_map)

            # Added now
            # Parse each training batch and accumulate span precision/recall against
            # the reference trees in the batch.
            trainer.net.eval()
            sentences = batch_map['sentences']
            trees = parse_predictor.parse_batch(batch_map)
            o_list = []
            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)
                o_list.append(o["tree"])

                if isinstance(batch_map["parse_tree"][ii], str):
                    parse_tree_tuple = str_to_tuple(batch_map["parse_tree"][ii])
                else:
                    parse_tree_tuple = batch_map["parse_tree"][ii]

                o_spans = tree_to_spans(o["tree"])
                batch_spans = tree_to_spans(parse_tree_tuple[0])
                p, r, t = precision_and_recall(batch_spans, o_spans)
                precision += p
                recall += r
                total_len += t
            trainer.net.train()
            # Added now

            experiment_logger.record(result)

            if step % options.log_every_batch == 0:
                experiment_logger.log_batch(epoch, step, batch_idx, batch_size=options.batch_size)

            # -- Periodic Checkpoints -- #

            if not options.multigpu or options.local_rank == 0:
                if step % options.save_latest == 0 and step >= options.save_after:
                    logger.info('Saving model (periodic).')
                    trainer.save_model(os.path.join(options.experiment_path, 'model_periodic.pt'))
                    save_experiment(os.path.join(options.experiment_path, 'experiment_periodic.json'), step)

                if step % options.save_distinct == 0 and step >= options.save_after:
                    logger.info('Saving model (distinct).')
                    trainer.save_model(os.path.join(options.experiment_path, 'model.step_{}.pt'.format(step)))
                    save_experiment(os.path.join(options.experiment_path, 'experiment.step_{}.json'.format(step)), step)

            del result

            step += 1

        # Added now
        print(precision, recall, total_len)
        print(precision / total_len, recall / total_len)
        print(count_des)
        # Added now

        experiment_logger.log_epoch(epoch, step)

        if options.max_step is not None and step >= options.max_step:
            logger.info('Max-Step={} Quitting.'.format(options.max_step))
            sys.exit()