Example #1
def run_parse(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']

    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse

    diora = trainer.net.diora

    ## Turn off outside pass.
    trainer.net.diora.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Parse predictor (CKY).
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    logger.info('Beginning to parse.')

    with torch.no_grad():
        for i, batch_map in enumerate(batches):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Rather than skipping short sentences, print their trees directly
            # (they are trivial at this length).
            if length <= 2:
                for j in range(batch_size):
                    example_id = batch_map['example_ids'][j]
                    tokens = sentences[j].tolist()
                    words = [idx2word[idx] for idx in tokens]
                    if length == 2:
                        o = dict(example_id=example_id, tree=(words[0], words[1]))
                    else:  # length == 1
                        o = dict(example_id=example_id, tree=words[0])
                    print(json.dumps(o))
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees = parse_predictor.parse_batch(batch_map)

            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)

                print(json.dumps(o))
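
A note on the helper used above: replace_leaves swaps the integer leaf positions in a predicted tree for the corresponding words. The real implementation lives in the DIORA codebase; the following is only a minimal sketch, assuming trees are nested tuples whose leaves are 0-based positions into the word list.

def replace_leaves(tree, words):
    # Sketch only: assumes leaves are 0-based positions into `words`.
    def helper(node):
        if isinstance(node, (tuple, list)):
            return tuple(helper(child) for child in node)
        return words[node]
    return helper(tree)

# e.g. replace_leaves(((0, 1), 2), ['the', 'big', 'dog'])
# -> (('the', 'big'), 'dog')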
Example #2
    def forward(self, sentences, neg_samples, diora, info):
        batch_size, length = sentences.shape
        size = diora.outside_h.shape[-1]

        # Get the score for the reference tree, built right-branching from
        # the annotated spans (a left-branching variant would use
        # self.makeLeftTree).
        gold_spans_r = self.makeRightTree(info['spans'])

        gold_scores = self.get_score_for_spans(sentences, diora.saved_scalars,
                                               gold_spans_r)

        # Get the score for the maximal tree.
        parse_predictor = CKY(net=diora, word2idx=self.word2idx)
        max_trees = parse_predictor.parse_batch({'sentences': sentences})
        max_spans = [tree_to_spans(x) for x in max_trees]

        max_scores = self.get_score_for_spans(sentences, diora.saved_scalars,
                                              max_spans)

        loss = max_scores - gold_scores + self.margin

        loss = loss.sum().view(1) / batch_size

        ret = dict(semi_supervised_parsing_loss=loss)

        return loss, ret
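
For context, the objective above is a structured margin loss: the CKY-maximal tree should not outscore the reference tree by more than the margin. A toy sketch with made-up score tensors (values are illustrative only; note that, as written above, the difference is not clamped at zero):

import torch

max_scores = torch.tensor([1.2, 0.7, 2.0])   # hypothetical per-example scores
gold_scores = torch.tensor([1.0, 0.9, 1.5])  # hypothetical per-example scores
margin = 1.0
batch_size = max_scores.shape[0]

loss = (max_scores - gold_scores + margin).sum().view(1) / batch_size
print(loss)  # tensor([1.1667])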
Example #3
    def forward(self, sentences, neg_samples, diora, info):
        batch_size, length = sentences.shape
        size = diora.outside_h.shape[-1]

        # TODO for semi-supervised create a "pseudo-tree" using external annotation (i.e. entity boundaries).

        # Get the score for the ground truth tree.
        gold_scores = self.get_score_for_spans(sentences, diora.saved_scalars, info['spans'])

        # Get the score for the maximal tree.
        parse_predictor = CKY(net=diora, word2idx=self.word2idx)
        max_trees = parse_predictor.parse_batch({'sentences': sentences})
        max_spans = [tree_to_spans(x) for x in max_trees]
        max_scores = self.get_score_for_spans(sentences, diora.saved_scalars, max_spans)

        loss = max_scores - gold_scores + self.margin

        loss = loss.sum().view(1) / batch_size

        ret = dict(semi_supervised_parsing_loss=loss)

        return loss, ret
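
Both losses rely on tree_to_spans to flatten a predicted tree into constituent spans. The exact span convention is defined elsewhere in the codebase; a minimal sketch, assuming (start, length) pairs over leaf positions and nested-tuple trees:

def tree_to_spans(tree):
    # Sketch only: emits (start, length) for every internal node.
    spans = []

    def helper(node, pos):
        if not isinstance(node, (tuple, list)):
            return 1  # a leaf covers one token
        size = 0
        for child in node:
            size += helper(child, pos + size)
        spans.append((pos, size))
        return size

    helper(tree, 0)
    return spans

# e.g. tree_to_spans((('the', 'big'), 'dog')) -> [(0, 2), (0, 3)]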
Example #4
    def forward_(self, sentences, neg_samples, diora, info):
        batch_size, length = sentences.shape
        size = diora.outside_h.shape[-1]

        # Get the score for the reference tree, built right-branching from
        # the annotated spans.
        gold_spans_r = self.makeRightTree(info['spans'])

        # Get the score for the maximal tree.
        parse_predictor = CKY(net=diora, word2idx=self.word2idx)
        max_trees = parse_predictor.parse_batch({'sentences': sentences})
        max_spans = [tree_to_spans(x) for x in max_trees]

        # Map each annotated span to the closest enclosing subtree among
        # the predicted spans.
        roots, diora_spans = self.findClosestParent(max_spans, info['spans'])

        gold_scores = self.get_score_for_spans_modified(
            sentences, diora.saved_scalars, gold_spans_r, info['spans'])
        max_scores = self.get_score_for_spans_modified(sentences,
                                                       diora.saved_scalars,
                                                       diora_spans, roots)

        total_loss = 0.

        for dp in range(len(gold_scores)):
            gold_score_data = gold_scores[dp]
            max_score_data = max_scores[dp]
            loss = 0.
            for i in range(len(gold_score_data)):
                # Skip spans where the predicted root already matches the
                # annotated span.
                if (int(info['spans'][dp][i][0]) == int(roots[dp][i][0]) and
                        int(info['spans'][dp][i][1]) == int(roots[dp][i][1])):
                    continue
                loss += max_score_data[i] - gold_score_data[i] + self.margin
            total_loss += loss

        # Keep the loss attached to the computation graph; wrapping a tensor
        # in torch.tensor(..., requires_grad=True) would detach it from the
        # span scores it was computed from.
        loss = total_loss / batch_size
        if not torch.is_tensor(loss):
            # No span contributed to the loss this batch.
            loss = torch.tensor(loss, requires_grad=True)
        loss = loss.view(1)

        ret = dict(semi_supervised_parsing_loss=loss)

        return loss, ret
Example #5
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']

    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse

    diora = trainer.net.encoder

    ## Monkey patch parsing-specific methods.
    override_init_with_batch(diora)
    override_inside_hook(diora)

    ## Outside pass is left enabled for this variant.

    ## Eval mode.
    trainer.net.eval()

    ## Parse predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    output_path1 = os.path.abspath(os.path.join(options.experiment_path, 'parse_mnli1.jsonl'))

    logger.info('Beginning.')
    logger.info('Writing output to = {}'.format(output_path1))

    f = open(output_path1, 'w')

    with torch.no_grad():
        for i, batch_map in tqdm(enumerate(batches)):
            sentences1 = batch_map['sentences_1']
            sentences2 = batch_map['sentences_2']
            batch_size = sentences1.shape[0]
            length = sentences1.shape[1]

            # Skip very short sentences.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees1 = parse_predictor.parse_batch(sentences1)
            trees2 = parse_predictor.parse_batch(sentences2)
            for ii, (tr1, tr2) in enumerate(zip(trees1, trees2)):
                example_id = batch_map['example_ids'][ii]
                s1 = [idx2word[idx] for idx in sentences1[ii].tolist()]
                s2 = [idx2word[idx] for idx in sentences2[ii].tolist()]
                tr1 = replace_leaves(tr1, s1)
                tr2 = replace_leaves(tr2, s2)
                if options.postprocess:
                    tr1 = postprocess(tr1, s1)
                    tr2 = postprocess(tr2, s2)
                o = collections.OrderedDict(example_id=example_id, sentence1=tr1, sentence2=tr2)

                f.write(json.dumps(o) + '\n')
  
    f.close()
Example #6
    def init(self, options):
        if options.parse_mode == 'latent':
            self.parse_predictor = CKY(net=self.diora, word2idx=self.word2idx)
            ## Monkey patch parsing-specific methods.
            override_init_with_batch(self.diora)
            override_inside_hook(self.diora)
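
The override_* calls above monkey-patch parsing-specific behavior onto the diora instance at runtime. Their bodies live in the codebase; the generic pattern they follow is sketched below with a hypothetical method name.

import types

def override_example_hook(instance):
    # Bind a replacement function as a method on this one instance,
    # leaving the class definition untouched.
    def example_hook(self, *args, **kwargs):
        # ... replacement behavior goes here ...
        return None
    instance.example_hook = types.MethodType(example_hook, instance)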
Example #7
def run(options):
    logger = get_logger()

    validation_dataset = get_validation_dataset(options)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    word2idx = validation_dataset['word2idx']
    embeddings = validation_dataset['embeddings']

    idx2word = {v: k for k, v in word2idx.items()}

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)

    # Parse

    diora = trainer.net.diora

    ## Monkey patch parsing-specific methods.
    override_init_with_batch(diora)
    override_inside_hook(diora)

    ## Turn off outside pass.
    trainer.net.diora.outside = False

    ## Eval mode.
    trainer.net.eval()

    ## Parse predictor.
    parse_predictor = CKY(net=diora, word2idx=word2idx)

    batches = validation_iterator.get_iterator(random_seed=options.seed)

    output_path = os.path.abspath(os.path.join(options.experiment_path, 'parse.jsonl'))

    logger.info('Beginning.')
    logger.info('Writing output to = {}'.format(output_path))

    f = open(output_path, 'w')

    with torch.no_grad():
        for i, batch_map in tqdm(enumerate(batches)):
            sentences = batch_map['sentences']
            batch_size = sentences.shape[0]
            length = sentences.shape[1]

            # Skip very short sentences.
            if length <= 2:
                continue

            _ = trainer.step(batch_map, train=False, compute_loss=False)

            trees = parse_predictor.parse_batch(batch_map)

            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                if options.postprocess:
                    tr = postprocess(tr, s)
                o = collections.OrderedDict(example_id=example_id, tree=tr)

                f.write(json.dumps(o) + '\n')

    f.close()
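
The output is one JSON object per line, with each parse stored as nested lists. A small sketch for consuming the file this function writes (the path assumes the same experiment_path used above):

import json

with open('parse.jsonl') as f:
    for line in f:
        record = json.loads(line)
        print(record['example_id'], record['tree'])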
Example #8
def run_train(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    seeds = generate_seeds(options.max_epoch, options.seed)

    step = 0

    # Parse with CKY during training to monitor unlabeled parsing accuracy.
    idx2word = {v: k for k, v in train_iterator.word2idx.items()}
    parse_predictor = CKY(net=trainer.net.diora,
                          word2idx=train_iterator.word2idx)

    for epoch, seed in zip(range(options.max_epoch), seeds):
        # --- Train--- #

        # Per-epoch span-evaluation counters.
        precision = 0
        recall = 0
        total_len = 0

        logger.info('epoch={} seed={}'.format(epoch, seed))

        def myiterator():
            it = train_iterator.get_iterator(random_seed=seed)

            count = 0

            for batch_map in it:
                # TODO: Skip short examples (optionally).
                if batch_map['length'] <= 2:
                    continue

                yield count, batch_map
                count += 1

        for batch_idx, batch_map in myiterator():
            if options.finetune and step >= options.finetune_after:
                trainer.freeze_diora()

            result = trainer.step(batch_map)

            # Evaluate the current model's parses against the reference trees.
            trainer.net.eval()
            sentences = batch_map['sentences']
            trees = parse_predictor.parse_batch(batch_map)
            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)
                if isinstance(batch_map["parse_tree"][ii], str):
                    parse_tree_tuple = str_to_tuple(
                        batch_map["parse_tree"][ii])
                else:
                    parse_tree_tuple = batch_map["parse_tree"][ii]

                o_spans = tree_to_spans(o["tree"])
                batch_spans = tree_to_spans(parse_tree_tuple[0])

                p, r, t = precision_and_recall(batch_spans, o_spans)
                precision += p
                recall += r
                total_len += t


            trainer.net.train()

            experiment_logger.record(result)

            if step % options.log_every_batch == 0:
                experiment_logger.log_batch(epoch,
                                            step,
                                            batch_idx,
                                            batch_size=options.batch_size)

            # -- Periodic Checkpoints -- #

            if not options.multigpu or options.local_rank == 0:
                if step % options.save_latest == 0 and step >= options.save_after:
                    logger.info('Saving model (periodic).')
                    trainer.save_model(
                        os.path.join(options.experiment_path,
                                     'model_periodic.pt'))
                    save_experiment(
                        os.path.join(options.experiment_path,
                                     'experiment_periodic.json'), step)

                if step % options.save_distinct == 0 and step >= options.save_after:
                    logger.info('Saving model (distinct).')
                    trainer.save_model(
                        os.path.join(options.experiment_path,
                                     'model.step_{}.pt'.format(step)))
                    save_experiment(
                        os.path.join(options.experiment_path,
                                     'experiment.step_{}.json'.format(step)),
                        step)

            del result

            step += 1
        # Report accumulated span counts and their averages for the epoch.
        print(precision, recall, total_len)
        if total_len > 0:
            print(precision / total_len, recall / total_len)
        experiment_logger.log_epoch(epoch, step)

        if options.max_step is not None and step >= options.max_step:
            logger.info('Max-Step={} Quitting.'.format(options.max_step))
            sys.exit()
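
The per-epoch counters printed above can be summarized into averaged precision/recall and an F1 score. A minimal sketch, assuming precision and recall are summed matched-span counts sharing total_len as normalizer (the actual semantics of precision_and_recall live in the codebase):

def summarize_parse_eval(precision, recall, total_len):
    # Sketch only: averaged precision/recall and their harmonic mean.
    avg_p = precision / total_len
    avg_r = recall / total_len
    f1 = 2 * precision * recall / (total_len * (precision + recall))
    return avg_p, avg_r, f1

# e.g. summarize_parse_eval(80, 70, 100) -> (0.8, 0.7, 0.7466...)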