Example #1
import os
import sys


def run(options):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    train_dataset, validation_dataset = get_train_and_validation(options)
    train_iterator = get_train_iterator(options, train_dataset)
    validation_iterator = get_validation_iterator(options, validation_dataset)
    embeddings = train_dataset['embeddings']

    logger.info('Initializing model.')
    trainer = build_net(options, embeddings, validation_iterator)
    logger.info('Model:')
    for name, p in trainer.net.named_parameters():
        logger.info('{} {}'.format(name, p.shape))

    if options.save_init:
        logger.info('Saving model (init).')
        trainer.save_model(
            os.path.join(options.experiment_path, 'model_init.pt'))

    if options.parse_only:
        run_parse(options, train_iterator, trainer, validation_iterator)
        sys.exit()

    run_train(options, train_iterator, trainer, validation_iterator)
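
For orientation, run() reads only a few attributes from options directly (experiment_path, save_init, parse_only); the rest are consumed by the helpers. A minimal, hypothetical entry point — the flag names and defaults below are assumptions for illustration, not part of the original project:

import argparse

# Hypothetical CLI wrapper around run(); only the flags read directly by
# run() above are shown. The dataset/model helpers require further options.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_path', default='./experiment')
    parser.add_argument('--save_init', action='store_true')
    parser.add_argument('--parse_only', action='store_true')
    options = parser.parse_args()
    run(options)
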
Example #2
import os
import sys


def run_train(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    seeds = generate_seeds(options.max_epoch, options.seed)

    step = 0

    for epoch, seed in zip(range(options.max_epoch), seeds):
        # --- Train --- #

        logger.info('epoch={} seed={}'.format(epoch, seed))

        def myiterator():
            it = train_iterator.get_iterator(random_seed=seed)

            count = 0

            for batch_map in it:
                # Skip examples of length <= 2 (TODO: make this optional).
                if batch_map['length'] <= 2:
                    continue

                yield count, batch_map
                count += 1

        for batch_idx, batch_map in myiterator():
            if options.finetune and step >= options.finetune_after:
                trainer.freeze_diora()

            result = trainer.step(batch_map)

            experiment_logger.record(result)

            if step % options.log_every_batch == 0:
                experiment_logger.log_batch(epoch,
                                            step,
                                            batch_idx,
                                            batch_size=options.batch_size)

            # -- Periodic Checkpoints -- #

            if not options.multigpu or options.local_rank == 0:
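                # Only the main process (local_rank 0) saves when running multi-GPU.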
                if step % options.save_latest == 0 and step >= options.save_after:
                    logger.info('Saving model (periodic).')
                    trainer.save_model(
                        os.path.join(options.experiment_path,
                                     'model_periodic.pt'))
                    save_experiment(
                        os.path.join(options.experiment_path,
                                     'experiment_periodic.json'), step)

                if step % options.save_distinct == 0 and step >= options.save_after:
                    logger.info('Saving model (distinct).')
                    trainer.save_model(
                        os.path.join(options.experiment_path,
                                     'model.step_{}.pt'.format(step)))
                    save_experiment(
                        os.path.join(options.experiment_path,
                                     'experiment.step_{}.json'.format(step)),
                        step)

            del result

            step += 1

        experiment_logger.log_epoch(epoch, step)

        if options.max_step is not None and step >= options.max_step:
            logger.info('Max-Step={} Quitting.'.format(options.max_step))
            sys.exit()
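
The loop draws one seed per epoch from generate_seeds, which the snippet does not define. A minimal sketch consistent with how it is called here — one deterministic value per epoch, derived from options.seed — could be:

import random

def generate_seeds(n, seed):
    # Assumed implementation: n reproducible per-epoch seeds derived from
    # the base seed. The original helper is not shown in the snippet.
    rng = random.Random(seed)
    return [rng.randint(0, 2**31 - 1) for _ in range(n)]
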
Example #3
import os
import sys


def run_train(options, train_iterator, trainer, validation_iterator):
    logger = get_logger()
    experiment_logger = ExperimentLogger()

    logger.info('Running train.')

    seeds = generate_seeds(options.max_epoch, options.seed)

    step = 0

    # Build a reverse vocabulary and a CKY parse predictor so trees can be
    # decoded from the DIORA chart during training.
    idx2word = {v: k for k, v in train_iterator.word2idx.items()}
    parse_predictor = CKY(net=trainer.net.diora,
                          word2idx=train_iterator.word2idx)

    for epoch, seed in zip(range(options.max_epoch), seeds):
        # --- Train --- #

        # Per-epoch counters for span-level evaluation.
        precision = 0
        recall = 0
        total_len = 0

        logger.info('epoch={} seed={}'.format(epoch, seed))

        def myiterator():
            it = train_iterator.get_iterator(random_seed=seed)

            count = 0

            for batch_map in it:
                # Skip examples of length <= 2 (TODO: make this optional).
                if batch_map['length'] <= 2:
                    continue

                yield count, batch_map
                count += 1

        for batch_idx, batch_map in myiterator():
            if options.finetune and step >= options.finetune_after:
                trainer.freeze_diora()

            result = trainer.step(batch_map)

            # Decode a parse for each sentence in the batch and score its
            # spans against the gold tree.
            trainer.net.eval()
            sentences = batch_map['sentences']
            trees = parse_predictor.parse_batch(batch_map)
            o_list = []
            for ii, tr in enumerate(trees):
                example_id = batch_map['example_ids'][ii]
                s = [idx2word[idx] for idx in sentences[ii].tolist()]
                tr = replace_leaves(tr, s)
                o = dict(example_id=example_id, tree=tr)
                o_list.append(o["tree"])
                # Gold trees may arrive serialized as strings; normalize to tuples.
                if isinstance(batch_map["parse_tree"][ii], str):
                    parse_tree_tuple = str_to_tuple(
                        batch_map["parse_tree"][ii])
                else:
                    parse_tree_tuple = batch_map["parse_tree"][ii]

                o_spans = tree_to_spans(o["tree"])
                batch_spans = tree_to_spans(parse_tree_tuple[0])

                p, r, t = precision_and_recall(batch_spans, o_spans)
                precision += p
                recall += r
                total_len += t

            # Restore training mode after parse evaluation.
            trainer.net.train()

            experiment_logger.record(result)

            if step % options.log_every_batch == 0:
                experiment_logger.log_batch(epoch,
                                            step,
                                            batch_idx,
                                            batch_size=options.batch_size)

            # -- Periodic Checkpoints -- #

            if not options.multigpu or options.local_rank == 0:
                if step % options.save_latest == 0 and step >= options.save_after:
                    logger.info('Saving model (periodic).')
                    trainer.save_model(
                        os.path.join(options.experiment_path,
                                     'model_periodic.pt'))
                    save_experiment(
                        os.path.join(options.experiment_path,
                                     'experiment_periodic.json'), step)

                if step % options.save_distinct == 0 and step >= options.save_after:
                    logger.info('Saving model (distinct).')
                    trainer.save_model(
                        os.path.join(options.experiment_path,
                                     'model.step_{}.pt'.format(step)))
                    save_experiment(
                        os.path.join(options.experiment_path,
                                     'experiment.step_{}.json'.format(step)),
                        step)

            del result

            step += 1
        # Report accumulated span statistics for the epoch
        # (assumes total_len > 0, i.e. at least one evaluated example).
        print(precision, recall, total_len)
        print(precision / total_len, recall / total_len)
        experiment_logger.log_epoch(epoch, step)

        if options.max_step is not None and step >= options.max_step:
            logger.info('Max-Step={} Quitting.'.format(options.max_step))
            sys.exit()
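
Example #3 relies on tree_to_spans and precision_and_recall, which the snippet does not define. A plausible minimal sketch, assuming trees are nested tuples with string leaves and that gold and predicted trees are both binarized over the same sentence — so the caller can divide both matched counts by one shared total, as the accumulation above does:

def tree_to_spans(tree):
    # Collect a (start, end) token span for every internal node of a
    # nested-tuple tree; string leaves cover one token and add no span.
    spans = []

    def helper(node, pos):
        if not isinstance(node, (list, tuple)):
            return 1
        size = 0
        for child in node:
            size += helper(child, pos + size)
        spans.append((pos, pos + size))
        return size

    helper(tree, 0)
    return spans

def precision_and_recall(gold_spans, pred_spans):
    # Unlabeled span overlap. For binarized trees over the same sentence
    # both span sets have the same size, so a single denominator serves
    # precision and recall alike.
    gold, pred = set(gold_spans), set(pred_spans)
    matched = len(gold & pred)
    return matched, matched, len(gold)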