def setUp(self):
     self.args = parser_utils.get_full_config("test/run_args/test_model_training.json")
     train_utils.setup_train_heads_and_eval_slices(self.args)
     self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
     self.entity_symbols = EntitySymbols(os.path.join(
         self.args.data_config.entity_dir, self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map)
     slices = WikiSlices(
         args=self.args,
         use_weak_label=False,
         input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
         dataset_name=os.path.join(self.args.data_config.data_dir, data_utils.generate_save_data_name(
             data_args=self.args.data_config, use_weak_label=True, split_name="slice_train")),
         is_writer=True,
         distributed=self.args.run_config.distributed,
         dataset_is_eval=False
     )
     self.data = WikiDataset(
         args=self.args,
         use_weak_label=False,
         input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
         dataset_name=os.path.join(self.args.data_config.data_dir, data_utils.generate_save_data_name(
             data_args=self.args.data_config, use_weak_label=False, split_name="train")),
         is_writer=True,
         distributed=self.args.run_config.distributed,
         word_symbols=self.word_symbols,
         entity_symbols=self.entity_symbols,
         slice_dataset=slices,
         dataset_is_eval=False
     )
     self.trainer = Trainer(self.args, self.entity_symbols, self.word_symbols)
Example #2
 def setUp(self) -> None:
     self.args = parser_utils.get_full_config(
         "test/run_args/test_embeddings.json")
     self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
     self.entity_symbols = EntitySymbolsSubclass()
     self.word_emb = WordEmbeddingMock(
         self.args.data_config.word_embedding, self.args, self.word_symbols)
     self.title_emb = AvgTitleEmb(self.args,
         self.args.data_config.ent_embeddings[1],
         "cpu", self.entity_symbols, self.word_symbols,
         word_emb=self.word_emb, key="avg_title")
Example #3
 def setUp(self) -> None:
     self.args = parser_utils.get_full_config(
         "test/run_args/test_embeddings.json")
     self.args.data_config.ent_embeddings = [
         DottedDict(
         {
             "key": "learned1",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned2",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned3",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned4",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
         DottedDict(
         {
             "key": "learned5",
             "load_class": "LearnedEntityEmb",
             "args": {
                 "learned_embedding_size": 5,
                 "tail_init": False
             }
         }),
     ]
     self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
     self.entity_symbols = EntitySymbolsSubclass()
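A note on the block above: the five ent_embeddings entries are identical apart from their keys, so the same list can be built programmatically. A minimal sketch, assuming DottedDict is imported as in the surrounding snippets and accepts a plain dict:

# Equivalent construction of the five LearnedEntityEmb configs used above.
ent_embeddings = [
    DottedDict({
        "key": f"learned{i}",
        "load_class": "LearnedEntityEmb",
        "args": {"learned_embedding_size": 5, "tail_init": False},
    })
    for i in range(1, 6)  # learned1 ... learned5
]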
Example #4
    def __init__(self,
                 config_args,
                 device='cuda',
                 max_alias_len=6,
                 cand_map=None,
                 threshold=0.0):
        self.args = config_args
        self.device = device
        self.entity_db = EntitySymbols(
            os.path.join(self.args.data_config.entity_dir,
                         self.args.data_config.entity_map_dir),
            alias_cand_map_file=self.args.data_config.alias_cand_map)
        self.word_db = data_utils.load_wordsymbols(self.args.data_config,
                                                   is_writer=True,
                                                   distributed=False)
        self.model = self._load_model()
        self.max_alias_len = max_alias_len
        if cand_map is None:
            alias_map = self.entity_db._alias2qids
        else:
            alias_map = ujson.load(open(cand_map))
        self.all_aliases_trie = get_all_aliases(alias_map,
                                                logger=logging.getLogger())
        self.alias_table = AliasEntityTable(args=self.args,
                                            entity_symbols=self.entity_db)

        # Minimum prediction probability required to return a mention
        self.threshold = threshold

        # Gather the batch_on_the_fly embeddings (batch_prep embeddings are promoted to batch_on_the_fly below)
        self.batch_on_the_fly_embs = {}
        for i, emb in enumerate(self.args.data_config.ent_embeddings):
            if 'batch_prep' in emb and emb['batch_prep'] is True:
                self.args.data_config.ent_embeddings[i][
                    'batch_on_the_fly'] = True
                del self.args.data_config.ent_embeddings[i]['batch_prep']
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=self.args,
                                         emb_args=emb['args'],
                                         entity_symbols=self.entity_db,
                                         model_device=None,
                                         word_symbols=None)
                except AttributeError as e:
                    print(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)
Example #5
    def test_edge_case(self):
        # Edge-case lengths

        # Test maximum sequence length
        max_aliases = 30
        max_seq_len = 3

        # Manual data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big alias1", "multi word alias2 and alias3"]
        aliases_to_predict = [0, 1]
        spans = [[0, 3], [8, 13]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The big alias1".split(), "multi word alias2".split()
        ]
        true_spans_arr = [[[0, 3]], [[0, 5]]]
        true_alias_to_predict_arr = [[0], [0]]
        true_aliases_arr = [["The big alias1"],
                            ["multi word alias2 and alias3"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
Example #6
    def test_split_sentence_alias_to_predict(self):
        # No splitting, but a change in aliases to predict... nothing should change
        max_aliases = 30
        max_seq_len = 24

        # Manually created data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big", "alias3", "alias5"]
        aliases_to_predict = [0, 1]
        spans = [[0, 2], [12, 13], [20, 21]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # Truth data
        true_phrase_arr = [
            "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
            .split(" ")
        ]
        true_spans_arr = [[[0, 2], [12, 13], [20, 21]]]
        true_alias_to_predict_arr = [[0, 1]]
        true_aliases_arr = [["The big", "alias3", "alias5"]]

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
Example #7
    def test_split_sentence_max_aliases(self):
        # Test if the sentence splits correctly when max_aliases is less than the number of aliases
        max_aliases = 2
        max_seq_len = 24

        # Manually created data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big", "alias3", "alias5"]
        aliases_to_predict = [0, 1, 2]
        spans = [[0, 2], [12, 13], [20, 21]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
            .split(" ")
        ] * 2
        true_spans_arr = [[[0, 2], [12, 13]], [[20, 21]]]
        true_alias_to_predict_arr = [[0, 1], [0]]
        true_aliases_arr = [["The big", "alias3"], ["alias5"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
Example #8
 def setUp(self) -> None:
     self.args = parser_utils.get_full_config(
         "test/run_args/test_embeddings.json")
     self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
     self.entity_symbols = EntitySymbolsSubclass()
Example #9
 def setUp(self) -> None:
     # TODO: replace with a custom vocab file instead of GloVe
     self.args = parser_utils.get_full_config(
         "test/run_args/test_embeddings.json")
     self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
Example #10
def model_eval(args, mode, is_writer, logger, world_size=1, rank=0):
    assert args.run_config.init_checkpoint != "", "You can't have an empty model file to do eval"
    # This is also done in main, but call again in case eval is invoked directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, mode)

    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    eval_slice_names = args.run_config.eval_slices
    test_dataset_collection = {}
    test_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.test_dataset, is_writer, dataset_is_eval=True)
    test_dataset = data_utils.create_dataset(args,
                                             args.data_config.test_dataset,
                                             is_writer,
                                             word_symbols,
                                             entity_symbols,
                                             slice_dataset=test_slice_dataset,
                                             dataset_is_eval=True)
    test_dataloader, test_sampler = data_utils.create_dataloader(
        args,
        test_dataset,
        eval_slice_dataset=test_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.test_dataset,
                                           args.data_config.test_dataset.file,
                                           test_dataset, test_dataloader,
                                           test_slice_dataset, test_sampler)
    test_dataset_collection[
        args.data_config.test_dataset.file] = dataset_collection

    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      resume_model_file=args.run_config.init_checkpoint,
                      eval_slice_names=eval_slice_names,
                      model_eval=True)

    # Run evaluation numbers without dumping predictions (quick, batched)
    if mode == 'eval':
        status_reporter = StatusReporter(args,
                                         logger,
                                         is_writer,
                                         max_epochs=None,
                                         total_steps_per_epoch=None,
                                         is_eval=True)
        # results are written to json file
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************RUNNING EVAL {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            # is_test=True marks the test split; global_step=None since we are not mid-training
            eval_utils.run_batched_eval(args=args,
                                        is_test=True,
                                        global_step=None,
                                        logger=logger,
                                        trainer=trainer,
                                        dataloader=test_dataloader,
                                        status_reporter=status_reporter,
                                        file=test_data_file)

    elif mode == 'dump_preds' or mode == 'dump_embs':
        # get predictions and optionally dump the corresponding contextual entity embeddings
        # TODO: support dumping ids for other embeddings as well (static entity embeddings, type embeddings, relation embeddings)
        # TODO: remove collection abstraction
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************DUMPING PREDICTIONS FOR {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            pred_file, emb_file = eval_utils.run_dump_preds(
                args=args,
                entity_symbols=entity_symbols,
                test_data_file=test_data_file,
                logger=logger,
                trainer=trainer,
                dataloader=test_dataloader,
                dump_embs=(mode == 'dump_embs'))
            return pred_file, emb_file
    return
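A minimal sketch of invoking model_eval directly, assuming parser_utils and model_eval are imported as in the snippet above; the config path is one of the test configs used elsewhere on this page and the checkpoint path is a placeholder:

import logging

# Hypothetical driver script; init_checkpoint must point at a real saved model.
args = parser_utils.get_full_config("test/run_args/test_model_training.json")
args.run_config.init_checkpoint = "path/to/model_checkpoint.pt"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# mode='eval' runs batched eval only; 'dump_preds' or 'dump_embs' also write
# predictions (and optionally contextual entity embeddings) to disk.
pred_file, emb_file = model_eval(args, mode="dump_preds", is_writer=True, logger=logger)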
Example #11
def train(args, is_writer, logger, world_size, rank):
    # This is also done in main, but call again in case train is invoked directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, "train")

    # Load word symbols (like tokenizers) and entity symbols (aka entity profiles)
    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    # Get train dataset
    train_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.train_dataset, is_writer, dataset_is_eval=False)
    train_dataset = data_utils.create_dataset(
        args,
        args.data_config.train_dataset,
        is_writer,
        word_symbols,
        entity_symbols,
        slice_dataset=train_slice_dataset,
        dataset_is_eval=False)
    train_dataloader, train_sampler = data_utils.create_dataloader(
        args,
        train_dataset,
        eval_slice_dataset=None,
        world_size=world_size,
        rank=rank,
        batch_size=args.train_config.batch_size)

    # Repeat for dev
    dev_dataset_collection = {}
    dev_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.dev_dataset, is_writer, dataset_is_eval=True)
    dev_dataset = data_utils.create_dataset(args,
                                            args.data_config.dev_dataset,
                                            is_writer,
                                            word_symbols,
                                            entity_symbols,
                                            slice_dataset=dev_slice_dataset,
                                            dataset_is_eval=True)
    dev_dataloader, dev_sampler = data_utils.create_dataloader(
        args,
        dev_dataset,
        eval_slice_dataset=dev_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.dev_dataset,
                                           args.data_config.dev_dataset.file,
                                           dev_dataset, dev_dataloader,
                                           dev_slice_dataset, dev_sampler)
    dev_dataset_collection[
        args.data_config.dev_dataset.file] = dataset_collection

    eval_slice_names = args.run_config.eval_slices

    total_steps_per_epoch = len(train_dataloader)
    # Create trainer---model, optimizer, and scorer
    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      total_steps_per_epoch=total_steps_per_epoch,
                      eval_slice_names=eval_slice_names,
                      resume_model_file=args.run_config.init_checkpoint)

    # Set up epochs and intervals for saving and evaluating
    max_epochs = int(args.run_config.max_epochs)
    eval_steps = int(args.run_config.eval_steps)
    log_steps = int(args.run_config.log_steps)
    save_steps = max(int(args.run_config.save_every_k_eval * eval_steps), 1)
    logger.info(
        f"Eval steps {eval_steps}, Log steps {log_steps}, Save steps {save_steps}, Total training examples per epoch {len(train_dataset)}"
    )
    status_reporter = StatusReporter(args,
                                     logger,
                                     is_writer,
                                     max_epochs,
                                     total_steps_per_epoch,
                                     is_eval=False)
    global_step = 0
    for epoch in range(trainer.start_epoch, trainer.start_epoch + max_epochs):
        # Reseed each epoch so checkpointing does not need to save/restore the RNG state
        torch.manual_seed(args.train_config.seed + epoch)
        np.random.seed(args.train_config.seed + epoch)
        if args.run_config.distributed:
            # for determinism across runs https://github.com/pytorch/examples/issues/501
            train_sampler.set_epoch(epoch)

        start_time_load = time.time()
        for i, batch in enumerate(train_dataloader):
            load_time = time.time() - start_time_load
            start_time = time.time()
            _, loss_pack, _, _ = trainer.update(batch)
            # Log progress
            if (global_step + 1) % log_steps == 0:
                duration = time.time() - start_time
                status_reporter.step_status(epoch=epoch,
                                            step=global_step,
                                            loss_pack=loss_pack,
                                            time=duration,
                                            load_time=load_time,
                                            lr=trainer.get_lr())
            # Save model
            if (global_step + 1) % save_steps == 0 and is_writer:
                logger.info("Saving model...")
                trainer.save(save_dir=train_utils.get_save_folder(
                    args.run_config),
                             epoch=epoch,
                             step=global_step,
                             step_in_batch=i,
                             suffix=args.run_config.model_suffix)
            # Run evaluation
            if (global_step + 1) % eval_steps == 0:
                eval_utils.run_eval_all_dev_sets(args, global_step,
                                                 dev_dataset_collection,
                                                 logger, status_reporter,
                                                 trainer)
            if args.run_config.distributed:
                dist.barrier()
            global_step += 1
            # Time loading new batch
            start_time_load = time.time()
        ######### END OF EPOCH
        if is_writer:
            logger.info(f"Saving model end of epoch {epoch}...")
            trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                         epoch=epoch,
                         step=global_step,
                         step_in_batch=i,
                         end_of_epoch=True,
                         suffix=args.run_config.model_suffix)
        # Always run eval after the end-of-epoch save, unless eval already ran at this step via eval_steps
        if (global_step + 1) % eval_steps != 0:
            eval_utils.run_eval_all_dev_sets(args, global_step,
                                             dev_dataset_collection, logger,
                                             status_reporter, trainer)
    if is_writer:
        logger.info("Saving model...")
        trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                     epoch=epoch,
                     step=global_step,
                     step_in_batch=i,
                     end_of_epoch=True,
                     last_epoch=True,
                     suffix=args.run_config.model_suffix)
    if args.run_config.distributed:
        dist.barrier()
Example #12
    def test_real_cases_bert(self):
        # Example 1
        max_aliases = 10
        max_seq_len = 100

        # Manual data
        sentence = "The guest roster for O'Brien 's final show on January 22\u2014 Tom Hanks , Steve Carell and original first guest Will Ferrell \u2014was regarded by O'Brien as a `` dream lineup '' ; in addition , Neil Young performed his song `` Long May You Run `` and , as the show closed , was joined by Beck , Ferrell ( dressed as Ronnie Van Zant ) , Billy Gibbons , Ben Harper , O'Brien , Viveca Paulin , and The Tonight Show Band to perform the Lynyrd Skynyrd song `` Free Bird `` ."
        aliases = [
            "tom hanks", "steve carell", "will ferrell", "neil young",
            "long may you run", "beck", "ronnie van zant", "billy gibbons",
            "ben harper", "viveca paulin", "lynyrd skynyrd", "free bird"
        ]
        spans = [[11, 13], [14, 16], [20, 22], [36, 38], [42, 46], [57, 58],
                 [63, 66], [68, 70], [71, 73], [76, 78], [87, 89], [91, 93]]
        aliases_to_predict = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

        # Truth
        true_phrase_arr = [
            [
                '[CLS]', 'The', 'guest', 'roster', 'for', 'O', "'", 'Brien',
                "'", 's', 'final', 'show', 'on', 'January', '22', '—', 'Tom',
                'Hank', '##s', ',', 'Steve', 'Care', '##ll', 'and', 'original',
                'first', 'guest', 'Will', 'Fe', '##rrell', '—', 'was',
                'regarded', 'by', 'O', "'", 'Brien', 'as', 'a', '`', '`',
                'dream', 'lineup', "'", "'", ';', 'in', 'addition', ',',
                'Neil', 'Young', 'performed', 'his', 'song', '`', '`', 'Long',
                'May', 'You', 'Run', '`', '`', 'and', ',', 'as', 'the', 'show',
                'closed', ',', 'was', 'joined', 'by', 'Beck', ',', 'Fe',
                '##rrell', '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z', '##ant',
                ')', ',', 'Billy', 'Gibbons', ',', 'Ben', 'Harper', ',', 'O',
                "'", 'Brien', ',', 'V', '##ive', '##ca', 'Paul', '##in', ',',
                '[SEP]'
            ],
            [
                '[CLS]', 'The', 'guest', 'roster', 'for', 'O', "'", 'Brien',
                "'", 's', 'final', 'show', 'on', 'January', '22',
                '—', 'Tom', 'Hank', '##s', ',', 'Steve', 'Care', '##ll', 'and',
                'original', 'first', 'guest', 'Will', 'Fe', '##rrell', '—',
                'was', 'regarded', 'by', 'O', "'", 'Brien', 'as', 'a', '`',
                '`', 'dream', 'lineup', "'", "'", ';', 'in', 'addition', ',',
                'Neil', 'Young', 'performed', 'his', 'song', '`', '`', 'Long',
                'May', 'You', 'Run', '`', '`', 'and', ',', 'as', 'the', 'show',
                'closed', ',', 'was', 'joined', 'by', 'Beck', ',', 'Fe',
                '##rrell', '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z', '##ant',
                ')', ',', 'Billy', 'Gibbons', ',', 'Ben', 'Harper', ',', 'O',
                "'", 'Brien', ',', 'V', '##ive', '##ca', 'Paul', '##in', ',',
                '[SEP]'
            ],
            [
                '[CLS]', 'original', 'first', 'guest', 'Will', 'Fe', '##rrell',
                '—', 'was', 'regarded', 'by', 'O', "'", 'Brien', 'as', 'a',
                '`', '`', 'dream', 'lineup', "'", "'", ';', 'in', 'addition',
                ',', 'Neil', 'Young', 'performed', 'his', 'song', '`', '`',
                'Long', 'May', 'You', 'Run', '`', '`', 'and', ',', 'as', 'the',
                'show', 'closed', ',', 'was', 'joined', 'by', 'Beck', ',',
                'Fe', '##rrell', '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z',
                '##ant', ')', ',', 'Billy', 'Gibbons', ',', 'Ben', 'Harper',
                ',', 'O', "'", 'Brien', ',', 'V', '##ive', '##ca', 'Paul',
                '##in', ',', 'and', 'The', 'Tonight', 'Show', 'Band', 'to',
                'perform', 'the', 'L', '##yn', '##yr', '##d', 'Sky', '##ny',
                '##rd', 'song', '`', '`', 'Free', 'Bird', '`', '`', '.',
                '[SEP]'
            ]
        ]
        true_spans_arr = [[[12, 14], [15, 17], [23, 25], [40, 42], [46, 50],
                           [61, 62], [67, 70], [72, 74]],
                          [[46, 50], [61, 62], [67, 70], [72, 74], [76, 78],
                           [81, 84], [93, 95], [100, 102]],
                          [[17, 19], [23, 27], [38, 39], [44, 47], [49, 51],
                           [53, 55], [58, 61], [70, 72], [77, 79]]]
        true_alias_to_predict_arr = [[0, 1, 2, 3, 4, 5], [2, 3, 4, 5], [7, 8]]
        true_aliases_arr = [[
            "tom hanks", "steve carell", "will ferrell", "neil young",
            "long may you run", "beck", "ronnie van zant", "billy gibbons"
        ],
                            [
                                "long may you run", "beck", "ronnie van zant",
                                "billy gibbons", "ben harper", "viveca paulin",
                                "lynyrd skynyrd", "free bird"
                            ],
                            [
                                "neil young", "long may you run", "beck",
                                "ronnie van zant", "billy gibbons",
                                "ben harper", "viveca paulin",
                                "lynyrd skynyrd", "free bird"
                            ]]
        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)
        assert len(idxs_arr) == 3
        assert len(aliases_to_predict_arr) == 3
        assert len(spans_arr) == 3
        assert len(phrase_tokens_arr) == 3

        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
Example #13
    def test_split_sentence_bert(self):

        # Example 1
        max_aliases = 30
        max_seq_len = 20

        # Manual data
        sentence = 'Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
        aliases = ["Kittens love", "alias2", "alias5"]
        spans = [[0, 2], [5, 6], [10, 11]]
        aliases_to_predict = [0, 1, 2]

        # Truth
        bert_tokenized = [
            'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp', '##pet',
            '##eers', 'because', 'alias', '##2', 'and', 'spanning', 'the',
            'br', '##rea', '##ches', 'alias', '##5'
        ]
        true_phrase_arr = [['[CLS]'] + bert_tokenized + ['[SEP]']]
        true_spans_arr = [[[1, 4], [11, 13], [19, 21]]]
        true_alias_to_predict_arr = [[0, 1, 2]]
        true_aliases_arr = [["Kittens love", "alias2", "alias5"]]

        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1

        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # Example 2
        max_aliases = 30
        max_seq_len = 7

        # Manual data
        sentence = 'Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
        aliases = ["Kittens love", "alias2", "alias5"]
        spans = [[0, 2], [5, 6], [10, 11]]
        aliases_to_predict = [0, 1, 2]

        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # Truth
        true_phrase_arr = [[
            '[CLS]', 'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp',
            '[SEP]'
        ],
                           [
                               '[CLS]', '##eers', 'because', 'alias', '##2',
                               'and', 'spanning', 'the', '[SEP]'
                           ],
                           [
                               '[CLS]', 'spanning', 'the', 'br', '##rea',
                               '##ches', 'alias', '##5', '[SEP]'
                           ]]
        true_spans_arr = [[[1, 4]], [[3, 5]], [[6, 8]]]
        true_alias_to_predict_arr = [[0], [0], [0]]
        true_aliases_arr = [["Kittens love"], ["alias2"], ["alias5"]]

        assert len(idxs_arr) == 3
        assert len(aliases_to_predict_arr) == 3
        assert len(spans_arr) == 3
        assert len(phrase_tokens_arr) == 3
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # Example 3: Test the greedy nature of the algorithm. It greedily packs the first two aliases into the first window, so the last alias lands in a second window even though the second alias also appears in that second split.
        max_aliases = 30
        max_seq_len = 18

        # Manual data
        sentence = 'Kittens Kittens Kittens Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
        aliases = ["Kittens love", "alias2", "alias5"]
        spans = [[3, 5], [8, 9], [13, 14]]
        aliases_to_predict = [0, 1, 2]

        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [[
            '[CLS]', '##tens', 'Kit', '##tens', 'Kit', '##tens', 'love',
            'purple', '##ish', 'pu', '##pp', '##pet', '##eers', 'because',
            'alias', '##2', 'and', 'spanning', 'the', '[SEP]'
        ],
                           [
                               '[CLS]', 'love', 'purple', '##ish', 'pu',
                               '##pp', '##pet', '##eers', 'because', 'alias',
                               '##2', 'and', 'spanning', 'the', 'br', '##rea',
                               '##ches', 'alias', '##5', '[SEP]'
                           ]]
        true_spans_arr = [[[4, 7], [14, 16]], [[9, 11], [17, 19]]]
        true_alias_to_predict_arr = [[0, 1], [1]]
        true_aliases_arr = [["Kittens love", "alias2"], ["alias2", "alias5"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
Example #14
    def test_real_cases(self):
        # Real examples we messed up

        # EXAMPLE 1
        max_aliases = 30
        max_seq_len = 50

        # 3114|0~*~1~*~2~*~3~*~4~*~5|mexico~*~panama~*~ecuador~*~peru~*~bolivia~*~colombia|3966054~*~22997~*~9334~*~170691~*~3462~*~5222|19:20~*~36:37~*~39:40~*~44:45~*~48:49~*~70:71|The animal is called paca in most of its range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia
        sentence = 'The animal is called paca in most of its range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia'
        aliases = [
            "mexico", "panama", "ecuador", "peru", "bolivia", "colombia"
        ]
        aliases_to_predict = [0, 1, 2, 3, 4, 5]
        spans = [[19, 20], [36, 37], [39, 40], [44, 45], [48, 49], [70, 71]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            'range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media'
            .split(),
            'Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia'
            .split()
        ]
        true_spans_arr = [[[10, 11], [27, 28], [30, 31], [35, 36], [39, 40]],
                          [[15, 16], [18, 19], [23, 24], [27, 28], [49, 50]]]
        true_alias_to_predict_arr = [[0, 1, 2, 3, 4], [4]]
        true_aliases_arr = [["mexico", "panama", "ecuador", "peru", "bolivia"],
                            [
                                "panama", "ecuador", "peru", "bolivia",
                                "colombia"
                            ]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # EXAMPLE 2
        max_aliases = 10
        max_seq_len = 50

        # 20|0~*~1~*~2~*~3~*~4~*~5~*~6~*~7~*~8~*~9~*~10~*~11~*~12~*~13~*~14~*~15~*~16~*~17~*~18~*~19~*~20|coolock~*~swords~*~darndale~*~santry~*~donnycarney~*~baldoyle~*~sutton~*~donaghmede~*~artane~*~whitehall~*~kilbarrack~*~raheny~*~clontarf~*~fairview~*~malahide~*~howth~*~marino~*~ballybough~*~north strand~*~sheriff street~*~east wall|1037463~*~182210~*~8554720~*~2432965~*~7890942~*~1223621~*~1008011~*~3698049~*~1469895~*~2144656~*~3628425~*~1108214~*~1564212~*~1438118~*~944694~*~1037467~*~5745962~*~2436385~*~5310245~*~12170199~*~2814197|12:13~*~14:15~*~15:16~*~17:18~*~18:19~*~19:20~*~20:21~*~21:22~*~22:23~*~23:24~*~24:25~*~25:26~*~26:27~*~27:28~*~28:29~*~29:30~*~30:31~*~38:39~*~39:41~*~41:43~*~43:45|East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall
        sentence = "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall"
        aliases = [
            "coolock", "swords", "darndale", "santry", "donnycarney",
            "baldoyle", "sutton", "donaghmede", "artane", "whitehall",
            "kilbarrack", "raheny", "clontarf", "fairview", "malahide",
            "howth", "marino", "ballybough", "north strand", "sheriff street",
            "east wall"
        ]
        aliases_to_predict = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20
        ]
        spans = [[12, 13], [14, 15], [15, 16], [17, 18], [18, 19], [19, 20],
                 [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26],
                 [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
                 [39, 41], [41, 43], [43, 45]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # Truth
        true_phrase_arr = [
            "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall <pad> <pad> <pad> <pad> <pad>"
            .split(),
            "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall <pad> <pad> <pad> <pad> <pad>"
            .split(),
            "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall <pad> <pad> <pad> <pad> <pad>"
            .split()
        ]
        true_spans_arr = [[[12, 13], [14, 15], [15, 16], [17, 18], [18, 19],
                           [19, 20], [20, 21], [21, 22]],
                          [[20, 21], [21, 22], [22, 23], [23, 24], [24, 25],
                           [25, 26], [26, 27], [27, 28], [28, 29]],
                          [[27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
                           [39, 41], [41, 43], [43, 45]]]
        true_alias_to_predict_arr = [[0, 1, 2, 3, 4, 5, 6],
                                     [1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 6, 7]]
        true_aliases_arr = [[
            "coolock", "swords", "darndale", "santry", "donnycarney",
            "baldoyle", "sutton", "donaghmede"
        ],
                            [
                                "sutton", "donaghmede", "artane", "whitehall",
                                "kilbarrack", "raheny", "clontarf", "fairview",
                                "malahide"
                            ],
                            [
                                "fairview", "malahide", "howth", "marino",
                                "ballybough", "north strand", "sheriff street",
                                "east wall"
                            ]]

        assert len(idxs_arr) == 3
        assert len(aliases_to_predict_arr) == 3
        assert len(spans_arr) == 3
        assert len(phrase_tokens_arr) == 3
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # EXAMPLE 3
        max_aliases = 10
        max_seq_len = 100

        # 84|0~*~1|kentucky~*~green|621151~*~478999|8:9~*~9:10|The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools
        sentence = "The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools"
        aliases = ["kentucky", "green"]
        aliases_to_predict = [0, 1]
        spans = [[8, 9], [9, 10]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>"
            .split()
        ]
        true_spans_arr = [[[8, 9], [9, 10]]]
        true_alias_to_predict_arr = [[0, 1]]
        true_aliases_arr = [["kentucky", "green"]]

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
Example #15
    def test_seq_length(self):

        # Test maximum sequence length
        max_aliases = 30
        max_seq_len = 12

        # Manual data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big", "alias3", "alias5"]
        aliases_to_predict = [0, 1, 2]
        spans = [[0, 2], [12, 13], [20, 21]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The big alias1 ran away from dogs and multi word alias2 and".
            split(),
            "word alias2 and alias3 because we want our cat and our alias5".
            split()
        ]
        true_spans_arr = [[[0, 2]], [[3, 4], [11, 12]]]
        true_alias_to_predict_arr = [[0], [0, 1]]
        true_aliases_arr = [["The big"], ["alias3", "alias5"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # Now test with modified aliases to predict
        aliases_to_predict = [1, 2]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "word alias2 and alias3 because we want our cat and our alias5".
            split()
        ]
        true_spans_arr = [[[3, 4], [11, 12]]]
        true_alias_to_predict_arr = [[0, 1]]
        true_aliases_arr = [["alias3", "alias5"]]

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])