def setUp(self):
    """Build the full training fixture: config, symbols, slice/train datasets, Trainer."""
    self.args = parser_utils.get_full_config("test/run_args/test_model_training.json")
    train_utils.setup_train_heads_and_eval_slices(self.args)
    self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
    self.entity_symbols = EntitySymbols(
        os.path.join(self.args.data_config.entity_dir,
                     self.args.data_config.entity_map_dir),
        alias_cand_map_file=self.args.data_config.alias_cand_map)
    # NOTE(review): WikiSlices is constructed with use_weak_label=False while
    # its dataset_name is generated with use_weak_label=True — confirm this
    # mismatch is intentional and not a copy-paste slip.
    slices = WikiSlices(
        args=self.args,
        use_weak_label=False,
        input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
        dataset_name=os.path.join(
            self.args.data_config.data_dir,
            data_utils.generate_save_data_name(
                data_args=self.args.data_config,
                use_weak_label=True,
                split_name="slice_train")),
        is_writer=True,
        distributed=self.args.run_config.distributed,
        dataset_is_eval=False
    )
    self.data = WikiDataset(
        args=self.args,
        use_weak_label=False,
        input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
        dataset_name=os.path.join(
            self.args.data_config.data_dir,
            data_utils.generate_save_data_name(
                data_args=self.args.data_config,
                use_weak_label=False,
                split_name="train")),
        is_writer=True,
        distributed=self.args.run_config.distributed,
        word_symbols=self.word_symbols,
        entity_symbols=self.entity_symbols,
        slice_dataset=slices,
        dataset_is_eval=False
    )
    self.trainer = Trainer(self.args, self.entity_symbols, self.word_symbols)
def setUp(self) -> None:
    """Build the mocked word embedding and the AvgTitleEmb instance under test."""
    config_path = "test/run_args/test_embeddings.json"
    self.args = parser_utils.get_full_config(config_path)
    self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
    self.entity_symbols = EntitySymbolsSubclass()
    # Word embedding is mocked so no real vectors are loaded.
    word_emb_config = self.args.data_config.word_embedding
    self.word_emb = WordEmbeddingMock(word_emb_config, self.args,
                                      self.word_symbols)
    # The second ent_embeddings entry holds the avg-title config in this test file.
    title_emb_config = self.args.data_config.ent_embeddings[1]
    self.title_emb = AvgTitleEmb(self.args,
                                 title_emb_config,
                                 "cpu",
                                 self.entity_symbols,
                                 self.word_symbols,
                                 word_emb=self.word_emb,
                                 key="avg_title")
def setUp(self) -> None:
    """Load the embedding test config and stub in five identical learned-entity
    embedding configs (learned1..learned5) with tail_init disabled."""
    self.args = parser_utils.get_full_config(
        "test/run_args/test_embeddings.json")
    # Build the five configs programmatically instead of repeating the same
    # literal five times — only the key suffix differs.
    self.args.data_config.ent_embeddings = [
        DottedDict({
            "key": f"learned{i}",
            "load_class": "LearnedEntityEmb",
            "args": {
                "learned_embedding_size": 5,
                "tail_init": False
            }
        }) for i in range(1, 6)
    ]
    self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
    self.entity_symbols = EntitySymbolsSubclass()
def __init__(self, config_args, device='cuda', max_alias_len=6, cand_map=None, threshold=0.0):
    """Set up an end-to-end annotator: entity/word symbols, the trained model,
    the alias trie used for mention extraction, and any batch-on-the-fly
    embedding objects.

    Args:
        config_args: full parsed run config.
        device: torch device string the model runs on.
        max_alias_len: maximum mention length (in words) to consider.
        cand_map: optional path to a JSON alias->candidates map; when None the
            entity DB's internal alias map is used.
        threshold: minimum prediction probability for a mention to be returned.
    """
    self.args = config_args
    self.device = device
    self.entity_db = EntitySymbols(
        os.path.join(self.args.data_config.entity_dir,
                     self.args.data_config.entity_map_dir),
        alias_cand_map_file=self.args.data_config.alias_cand_map)
    self.word_db = data_utils.load_wordsymbols(self.args.data_config,
                                               is_writer=True,
                                               distributed=False)
    self.model = self._load_model()
    self.max_alias_len = max_alias_len
    if cand_map is None:
        alias_map = self.entity_db._alias2qids
    else:
        # NOTE(review): file handle from open() is never closed here.
        alias_map = ujson.load(open(cand_map))
    self.all_aliases_trie = get_all_aliases(alias_map,
                                            logger=logging.getLogger())
    self.alias_table = AliasEntityTable(args=self.args,
                                        entity_symbols=self.entity_db)
    # minimum probability of prediction to return mention
    self.threshold = threshold
    # get batch_on_the_fly embeddings _and_ the batch_prep embeddings
    self.batch_on_the_fly_embs = {}
    for i, emb in enumerate(self.args.data_config.ent_embeddings):
        # Rewrite batch_prep configs to batch_on_the_fly: at inference there is
        # no offline prep step, so these must be computed per batch. `emb`
        # aliases the same dict, so the second check below sees the rewrite.
        if 'batch_prep' in emb and emb['batch_prep'] is True:
            self.args.data_config.ent_embeddings[i][
                'batch_on_the_fly'] = True
            del self.args.data_config.ent_embeddings[i]['batch_prep']
        if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
            mod, load_class = import_class("bootleg.embeddings",
                                           emb.load_class)
            # Best-effort: a failed embedding load is reported but does not
            # abort annotator construction.
            try:
                self.batch_on_the_fly_embs[emb.key] = getattr(
                    mod, load_class)(main_args=self.args,
                                     emb_args=emb['args'],
                                     entity_symbols=self.entity_db,
                                     model_device=None,
                                     word_symbols=None)
            except AttributeError as e:
                print(
                    f'No prep method found for {emb.load_class} with error {e}'
                )
            except Exception as e:
                print("ERROR", e)
def test_edge_case(self):
    """Splitting when max_seq_len is far smaller than the sentence: each alias
    lands in its own truncated window."""
    max_aliases = 30
    max_seq_len = 3
    # Hand-built input
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big alias1", "multi word alias2 and alias3"]
    aliases_to_predict = [0, 1]
    spans = [[0, 3], [8, 13]]
    # Run the splitter
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # Expected output
    expected_phrases = [
        "The big alias1".split(), "multi word alias2".split()
    ]
    expected_spans = [[[0, 3]], [[0, 5]]]
    expected_to_predict = [[0], [0]]
    expected_aliases = [["The big alias1"], ["multi word alias2 and alias3"]]
    for arr in (idxs_arr, aliases_to_predict_arr, spans_arr,
                phrase_tokens_arr):
        assert len(arr) == 2
    for idxs, to_predict, span, phrase, exp_phrase, exp_span, exp_pred, exp_alias in zip(
            idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr,
            expected_phrases, expected_spans, expected_to_predict,
            expected_aliases):
        self.assertEqual(len(phrase), max_seq_len)
        self.assertEqual(phrase, exp_phrase)
        self.assertEqual(span, exp_span)
        self.assertEqual(to_predict, exp_pred)
        self.assertEqual([aliases[idx] for idx in idxs], exp_alias)
def test_split_sentence_alias_to_predict(self):
    """A subset of aliases_to_predict with no splitting: output is unchanged."""
    # No splitting but change in aliases to predict...nothing should change
    max_aliases = 30
    max_seq_len = 24
    # Manually created data
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big", "alias3", "alias5"]
    aliases_to_predict = [0, 1]
    spans = [[0, 2], [12, 13], [20, 21]]
    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # Truth data
    true_phrase_arr = [
        "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
        .split(" ")
    ]
    true_spans_arr = [[[0, 2], [12, 13], [20, 21]]]
    true_alias_to_predict_arr = [[0, 1]]
    true_aliases_arr = [["The big", "alias3", "alias5"]]
    assert len(idxs_arr) == 1
    assert len(aliases_to_predict_arr) == 1
    assert len(spans_arr) == 1
    assert len(phrase_tokens_arr) == 1
    for i in range(len(idxs_arr)):
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
def test_split_sentence_max_aliases(self):
    """Splitting driven by max_aliases: with only 2 aliases per window, three
    aliases force two windows over the same padded sentence."""
    max_aliases = 2
    max_seq_len = 24
    # Hand-built input
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big", "alias3", "alias5"]
    aliases_to_predict = [0, 1, 2]
    spans = [[0, 2], [12, 13], [20, 21]]
    # Run the splitter
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # Expected output: both windows share the identical padded phrase
    padded_phrase = (
        "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
        .split(" "))
    expected_phrases = [padded_phrase] * 2
    expected_spans = [[[0, 2], [12, 13]], [[20, 21]]]
    expected_to_predict = [[0, 1], [0]]
    expected_aliases = [["The big", "alias3"], ["alias5"]]
    for arr in (idxs_arr, aliases_to_predict_arr, spans_arr,
                phrase_tokens_arr):
        assert len(arr) == 2
    for idxs, to_predict, span, phrase, exp_phrase, exp_span, exp_pred, exp_alias in zip(
            idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr,
            expected_phrases, expected_spans, expected_to_predict,
            expected_aliases):
        self.assertEqual(len(phrase), max_seq_len)
        self.assertEqual(phrase, exp_phrase)
        self.assertEqual(span, exp_span)
        self.assertEqual(to_predict, exp_pred)
        self.assertEqual([aliases[idx] for idx in idxs], exp_alias)
def setUp(self) -> None:
    """Load the shared embedding-test config plus word and entity symbols."""
    config_file = "test/run_args/test_embeddings.json"
    self.args = parser_utils.get_full_config(config_file)
    self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
    self.entity_symbols = EntitySymbolsSubclass()
def setUp(self) -> None:
    """Load the embedding-test config and its word symbols."""
    # TODO: replace with custom vocab file and not GloVE
    self.args = parser_utils.get_full_config(
        "test/run_args/test_embeddings.json")
    self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
def model_eval(args, mode, is_writer, logger, world_size=1, rank=0):
    """Evaluate a trained model (from init_checkpoint) over the test dataset.

    Modes:
        'eval'       -- batched eval metrics, reported via StatusReporter.
        'dump_preds' -- dump per-mention predictions to file.
        'dump_embs'  -- like dump_preds, plus contextual entity embeddings.

    Returns:
        (pred_file, emb_file) for the dump modes; None for 'eval'.
    """
    assert args.run_config.init_checkpoint != "", "You can't have an empty model file to do eval"
    # this is in main but call again in case eval is called directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, mode)
    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    eval_slice_names = args.run_config.eval_slices
    # Collection keyed by test-file name (currently only one entry).
    test_dataset_collection = {}
    test_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.test_dataset, is_writer, dataset_is_eval=True)
    test_dataset = data_utils.create_dataset(args,
                                             args.data_config.test_dataset,
                                             is_writer,
                                             word_symbols,
                                             entity_symbols,
                                             slice_dataset=test_slice_dataset,
                                             dataset_is_eval=True)
    test_dataloader, test_sampler = data_utils.create_dataloader(
        args,
        test_dataset,
        eval_slice_dataset=test_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.test_dataset,
                                           args.data_config.test_dataset.file,
                                           test_dataset, test_dataloader,
                                           test_slice_dataset, test_sampler)
    test_dataset_collection[
        args.data_config.test_dataset.file] = dataset_collection
    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      resume_model_file=args.run_config.init_checkpoint,
                      eval_slice_names=eval_slice_names,
                      model_eval=True)
    # Run evaluation numbers without dumping predictions (quick, batched)
    if mode == 'eval':
        status_reporter = StatusReporter(args,
                                         logger,
                                         is_writer,
                                         max_epochs=None,
                                         total_steps_per_epoch=None,
                                         is_eval=True)
        # results are written to json file
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************RUNNING EVAL {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            # True is for if the batch is test or not, None is for the global step
            eval_utils.run_batched_eval(args=args,
                                        is_test=True,
                                        global_step=None,
                                        logger=logger,
                                        trainer=trainer,
                                        dataloader=test_dataloader,
                                        status_reporter=status_reporter,
                                        file=test_data_file)
    elif mode == 'dump_preds' or mode == 'dump_embs':
        # get predictions and optionally dump the corresponding contextual entity embeddings
        # TODO: support dumping ids for other embeddings as well (static entity embeddings, type embeddings, relation embeddings)
        # TODO: remove collection abstraction
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************DUMPING PREDICTIONS FOR {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            pred_file, emb_file = eval_utils.run_dump_preds(
                args=args,
                entity_symbols=entity_symbols,
                test_data_file=test_data_file,
                logger=logger,
                trainer=trainer,
                dataloader=test_dataloader,
                dump_embs=(mode == 'dump_embs'))
            # NOTE(review): returns inside the loop, so only the first test
            # file is dumped — presumably the collection always holds exactly
            # one file; confirm before adding more.
            return pred_file, emb_file
    return
def train(args, is_writer, logger, world_size, rank):
    """Main training loop: build train/dev datasets and a Trainer, then run
    epochs with periodic logging, checkpointing, and dev evaluation.

    Args:
        args: full parsed run config.
        is_writer: whether this process saves checkpoints / writes logs.
        logger: run logger.
        world_size: number of distributed processes.
        rank: this process's distributed rank.
    """
    # This is main but call again in case train is called directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, "train")
    # Load word symbols (like tokenizers) and entity symbols (aka entity profiles)
    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    # Get train dataset
    train_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.train_dataset, is_writer, dataset_is_eval=False)
    train_dataset = data_utils.create_dataset(
        args,
        args.data_config.train_dataset,
        is_writer,
        word_symbols,
        entity_symbols,
        slice_dataset=train_slice_dataset,
        dataset_is_eval=False)
    train_dataloader, train_sampler = data_utils.create_dataloader(
        args,
        train_dataset,
        eval_slice_dataset=None,
        world_size=world_size,
        rank=rank,
        batch_size=args.train_config.batch_size)
    # Repeat for dev
    dev_dataset_collection = {}
    dev_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.dev_dataset, is_writer, dataset_is_eval=True)
    dev_dataset = data_utils.create_dataset(args,
                                            args.data_config.dev_dataset,
                                            is_writer,
                                            word_symbols,
                                            entity_symbols,
                                            slice_dataset=dev_slice_dataset,
                                            dataset_is_eval=True)
    dev_dataloader, dev_sampler = data_utils.create_dataloader(
        args,
        dev_dataset,
        eval_slice_dataset=dev_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.dev_dataset,
                                           args.data_config.dev_dataset.file,
                                           dev_dataset, dev_dataloader,
                                           dev_slice_dataset, dev_sampler)
    dev_dataset_collection[
        args.data_config.dev_dataset.file] = dataset_collection
    eval_slice_names = args.run_config.eval_slices
    total_steps_per_epoch = len(train_dataloader)
    # Create trainer---model, optimizer, and scorer
    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      total_steps_per_epoch=total_steps_per_epoch,
                      eval_slice_names=eval_slice_names,
                      resume_model_file=args.run_config.init_checkpoint)
    # Set up epochs and intervals for saving and evaluating
    max_epochs = int(args.run_config.max_epochs)
    eval_steps = int(args.run_config.eval_steps)
    log_steps = int(args.run_config.log_steps)
    save_steps = max(int(args.run_config.save_every_k_eval * eval_steps), 1)
    logger.info(
        f"Eval steps {eval_steps}, Log steps {log_steps}, Save steps {save_steps}, Total training examples per epoch {len(train_dataset)}"
    )
    status_reporter = StatusReporter(args,
                                     logger,
                                     is_writer,
                                     max_epochs,
                                     total_steps_per_epoch,
                                     is_eval=False)
    global_step = 0
    for epoch in range(trainer.start_epoch, trainer.start_epoch + max_epochs):
        # this is to fix having to save/restore the RNG state for checkpointing
        torch.manual_seed(args.train_config.seed + epoch)
        np.random.seed(args.train_config.seed + epoch)
        if args.run_config.distributed:
            # for determinism across runs https://github.com/pytorch/examples/issues/501
            train_sampler.set_epoch(epoch)
        start_time_load = time.time()
        for i, batch in enumerate(train_dataloader):
            load_time = time.time() - start_time_load
            start_time = time.time()
            _, loss_pack, _, _ = trainer.update(batch)
            # Log progress
            if (global_step + 1) % log_steps == 0:
                duration = time.time() - start_time
                status_reporter.step_status(epoch=epoch,
                                            step=global_step,
                                            loss_pack=loss_pack,
                                            time=duration,
                                            load_time=load_time,
                                            lr=trainer.get_lr())
            # Save model
            if (global_step + 1) % save_steps == 0 and is_writer:
                logger.info("Saving model...")
                trainer.save(save_dir=train_utils.get_save_folder(
                    args.run_config),
                             epoch=epoch,
                             step=global_step,
                             step_in_batch=i,
                             suffix=args.run_config.model_suffix)
            # Run evaluation
            if (global_step + 1) % eval_steps == 0:
                eval_utils.run_eval_all_dev_sets(args, global_step,
                                                 dev_dataset_collection,
                                                 logger, status_reporter,
                                                 trainer)
                if args.run_config.distributed:
                    dist.barrier()
            global_step += 1
            # Time loading new batch
            start_time_load = time.time()
        ######### END OF EPOCH
        # NOTE(review): `i` below is the last batch index from the inner loop;
        # this assumes the dataloader is non-empty — confirm.
        if is_writer:
            logger.info(f"Saving model end of epoch {epoch}...")
            trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                         epoch=epoch,
                         step=global_step,
                         step_in_batch=i,
                         end_of_epoch=True,
                         suffix=args.run_config.model_suffix)
        # Always run eval when saving -- if this coincided with eval_step, then don't need to rerun eval
        if (global_step + 1) % eval_steps != 0:
            eval_utils.run_eval_all_dev_sets(args, global_step,
                                             dev_dataset_collection, logger,
                                             status_reporter, trainer)
    if is_writer:
        logger.info("Saving model...")
        trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                     epoch=epoch,
                     step=global_step,
                     step_in_batch=i,
                     end_of_epoch=True,
                     last_epoch=True,
                     suffix=args.run_config.model_suffix)
    if args.run_config.distributed:
        dist.barrier()
def test_real_cases_bert(self):
    """Regression test: wordpiece (BERT) splitting of a long real sentence with
    12 aliases and max_seq_len=100 must produce exactly 3 windows with these
    tokenizations, spans (in wordpiece coordinates), and predict indices."""
    # Example 1
    max_aliases = 10
    max_seq_len = 100
    # Manual data
    sentence = "The guest roster for O'Brien 's final show on January 22\u2014 Tom Hanks , Steve Carell and original first guest Will Ferrell \u2014was regarded by O'Brien as a `` dream lineup '' ; in addition , Neil Young performed his song `` Long May You Run `` and , as the show closed , was joined by Beck , Ferrell ( dressed as Ronnie Van Zant ) , Billy Gibbons , Ben Harper , O'Brien , Viveca Paulin , and The Tonight Show Band to perform the Lynyrd Skynyrd song `` Free Bird `` ."
    aliases = [
        "tom hanks", "steve carell", "will ferrell", "neil young",
        "long may you run", "beck", "ronnie van zant", "billy gibbons",
        "ben harper", "viveca paulin", "lynyrd skynyrd", "free bird"
    ]
    # Spans are in whitespace-token coordinates of `sentence`.
    spans = [[11, 13], [14, 16], [20, 22], [36, 38], [42, 46], [57, 58],
             [63, 66], [68, 70], [71, 73], [76, 78], [87, 89], [91, 93]]
    aliases_to_predict = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    # Truth
    true_phrase_arr = [
        [
            '[CLS]', 'The', 'guest', 'roster', 'for', 'O', "'", 'Brien', "'",
            's', 'final', 'show', 'on', 'January', '22', '—', 'Tom', 'Hank',
            '##s', ',', 'Steve', 'Care', '##ll', 'and', 'original', 'first',
            'guest', 'Will', 'Fe', '##rrell', '—', 'was', 'regarded', 'by',
            'O', "'", 'Brien', 'as', 'a', '`', '`', 'dream', 'lineup', "'",
            "'", ';', 'in', 'addition', ',', 'Neil', 'Young', 'performed',
            'his', 'song', '`', '`', 'Long', 'May', 'You', 'Run', '`', '`',
            'and', ',', 'as', 'the', 'show', 'closed', ',', 'was', 'joined',
            'by', 'Beck', ',', 'Fe', '##rrell', '(', 'dressed', 'as',
            'Ronnie', 'Van', 'Z', '##ant', ')', ',', 'Billy', 'Gibbons', ',',
            'Ben', 'Harper', ',', 'O', "'", 'Brien', ',', 'V', '##ive',
            '##ca', 'Paul', '##in', ',', '[SEP]'
        ],
        [
            '[CLS]', 'The', 'guest', 'roster', 'for', 'O', "'", 'Brien', "'",
            's', 'final', 'show', 'on', 'January', '22', '—', 'Tom', 'Hank',
            '##s', ',', 'Steve', 'Care', '##ll', 'and', 'original', 'first',
            'guest', 'Will', 'Fe', '##rrell', '—', 'was', 'regarded', 'by',
            'O', "'", 'Brien', 'as', 'a', '`', '`', 'dream', 'lineup', "'",
            "'", ';', 'in', 'addition', ',', 'Neil', 'Young', 'performed',
            'his', 'song', '`', '`', 'Long', 'May', 'You', 'Run', '`', '`',
            'and', ',', 'as', 'the', 'show', 'closed', ',', 'was', 'joined',
            'by', 'Beck', ',', 'Fe', '##rrell', '(', 'dressed', 'as',
            'Ronnie', 'Van', 'Z', '##ant', ')', ',', 'Billy', 'Gibbons', ',',
            'Ben', 'Harper', ',', 'O', "'", 'Brien', ',', 'V', '##ive',
            '##ca', 'Paul', '##in', ',', '[SEP]'
        ],
        [
            '[CLS]', 'original', 'first', 'guest', 'Will', 'Fe', '##rrell',
            '—', 'was', 'regarded', 'by', 'O', "'", 'Brien', 'as', 'a', '`',
            '`', 'dream', 'lineup', "'", "'", ';', 'in', 'addition', ',',
            'Neil', 'Young', 'performed', 'his', 'song', '`', '`', 'Long',
            'May', 'You', 'Run', '`', '`', 'and', ',', 'as', 'the', 'show',
            'closed', ',', 'was', 'joined', 'by', 'Beck', ',', 'Fe', '##rrell',
            '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z', '##ant', ')', ',',
            'Billy', 'Gibbons', ',', 'Ben', 'Harper', ',', 'O', "'", 'Brien',
            ',', 'V', '##ive', '##ca', 'Paul', '##in', ',', 'and', 'The',
            'Tonight', 'Show', 'Band', 'to', 'perform', 'the', 'L', '##yn',
            '##yr', '##d', 'Sky', '##ny', '##rd', 'song', '`', '`', 'Free',
            'Bird', '`', '`', '.', '[SEP]'
        ]
    ]
    # Spans below are in wordpiece coordinates of each window.
    true_spans_arr = [[[12, 14], [15, 17], [23, 25], [40, 42], [46, 50],
                       [61, 62], [67, 70], [72, 74]],
                      [[46, 50], [61, 62], [67, 70], [72, 74], [76, 78],
                       [81, 84], [93, 95], [100, 102]],
                      [[17, 19], [23, 27], [38, 39], [44, 47], [49, 51],
                       [53, 55], [58, 61], [70, 72], [77, 79]]]
    true_alias_to_predict_arr = [[0, 1, 2, 3, 4, 5], [2, 3, 4, 5], [7, 8]]
    true_aliases_arr = [[
        "tom hanks", "steve carell", "will ferrell", "neil young",
        "long may you run", "beck", "ronnie van zant", "billy gibbons"
    ],
                        [
                            "long may you run", "beck", "ronnie van zant",
                            "billy gibbons", "ben harper", "viveca paulin",
                            "lynyrd skynyrd", "free bird"
                        ],
                        [
                            "neil young", "long may you run", "beck",
                            "ronnie van zant", "billy gibbons", "ben harper",
                            "viveca paulin", "lynyrd skynyrd", "free bird"
                        ]]
    # Run function
    args = parser_utils.get_full_config(
        "test/run_args/test_data_bert.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    assert len(idxs_arr) == 3
    assert len(aliases_to_predict_arr) == 3
    assert len(spans_arr) == 3
    assert len(phrase_tokens_arr) == 3
    for i in range(len(idxs_arr)):
        # +2 accounts for the [CLS]/[SEP] wrapper tokens.
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
def test_split_sentence_bert(self):
    """Wordpiece (BERT) sentence splitting: no split when everything fits,
    per-alias windows when max_seq_len is tiny, and greedy packing of aliases
    into windows."""
    # Example 1
    max_aliases = 30
    max_seq_len = 20
    # Manual data
    sentence = 'Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
    aliases = ["Kittens love", "alias2", "alias5"]
    spans = [[0, 2], [5, 6], [10, 11]]
    aliases_to_predict = [0, 1, 2]
    # Truth
    bert_tokenized = [
        'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp', '##pet',
        '##eers', 'because', 'alias', '##2', 'and', 'spanning', 'the', 'br',
        '##rea', '##ches', 'alias', '##5'
    ]
    true_phrase_arr = [['[CLS]'] + bert_tokenized + ['[SEP]']]
    # Spans are in wordpiece coordinates (offset by 1 for [CLS]).
    true_spans_arr = [[[1, 4], [11, 13], [19, 21]]]
    true_alias_to_predict_arr = [[0, 1, 2]]
    true_aliases_arr = [["Kittens love", "alias2", "alias5"]]
    # Run function
    args = parser_utils.get_full_config(
        "test/run_args/test_data_bert.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    assert len(idxs_arr) == 1
    assert len(aliases_to_predict_arr) == 1
    assert len(spans_arr) == 1
    assert len(phrase_tokens_arr) == 1
    for i in range(len(idxs_arr)):
        # +2 accounts for the [CLS]/[SEP] wrapper tokens.
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
    # Example 2
    max_aliases = 30
    max_seq_len = 7
    # Manual data
    sentence = 'Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
    aliases = ["Kittens love", "alias2", "alias5"]
    spans = [[0, 2], [5, 6], [10, 11]]
    aliases_to_predict = [0, 1, 2]
    # Run function
    args = parser_utils.get_full_config(
        "test/run_args/test_data_bert.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # Truth
    true_phrase_arr = [[
        '[CLS]', 'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp',
        '[SEP]'
    ],
                       [
                           '[CLS]', '##eers', 'because', 'alias', '##2',
                           'and', 'spanning', 'the', '[SEP]'
                       ],
                       [
                           '[CLS]', 'spanning', 'the', 'br', '##rea',
                           '##ches', 'alias', '##5', '[SEP]'
                       ]]
    true_spans_arr = [[[1, 4]], [[3, 5]], [[6, 8]]]
    true_alias_to_predict_arr = [[0], [0], [0]]
    true_aliases_arr = [["Kittens love"], ["alias2"], ["alias5"]]
    assert len(idxs_arr) == 3
    assert len(aliases_to_predict_arr) == 3
    assert len(spans_arr) == 3
    assert len(phrase_tokens_arr) == 3
    for i in range(len(idxs_arr)):
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
    # Example 3: Test greedy nature of algorithm. It will greedily pack the
    # first two aliases together and the last alias will be split up even
    # though the second alias is also in the second split.
    max_aliases = 30
    max_seq_len = 18
    # Manual data
    sentence = 'Kittens Kittens Kittens Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
    aliases = ["Kittens love", "alias2", "alias5"]
    spans = [[3, 5], [8, 9], [13, 14]]
    aliases_to_predict = [0, 1, 2]
    # Run function
    args = parser_utils.get_full_config(
        "test/run_args/test_data_bert.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # True data
    true_phrase_arr = [[
        '[CLS]', '##tens', 'Kit', '##tens', 'Kit', '##tens', 'love',
        'purple', '##ish', 'pu', '##pp', '##pet', '##eers', 'because',
        'alias', '##2', 'and', 'spanning', 'the', '[SEP]'
    ],
                       [
                           '[CLS]', 'love', 'purple', '##ish', 'pu', '##pp',
                           '##pet', '##eers', 'because', 'alias', '##2',
                           'and', 'spanning', 'the', 'br', '##rea', '##ches',
                           'alias', '##5', '[SEP]'
                       ]]
    true_spans_arr = [[[4, 7], [14, 16]], [[9, 11], [17, 19]]]
    true_alias_to_predict_arr = [[0, 1], [1]]
    true_aliases_arr = [["Kittens love", "alias2"], ["alias2", "alias5"]]
    assert len(idxs_arr) == 2
    assert len(aliases_to_predict_arr) == 2
    assert len(spans_arr) == 2
    assert len(phrase_tokens_arr) == 2
    for i in range(len(idxs_arr)):
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
def test_real_cases(self):
    """Regression tests on real sentences the splitter previously mishandled.

    Fixes applied in review:
    - EXAMPLE 2's aliases_to_predict contained the garbled values `171, 8`
      where the sequence 0..20 clearly requires `17, 18` (171 is out of range
      for the 21 aliases).
    - The third example was mislabeled "Example 2".
    - Duplicated truth-phrase literals are built once and reused.
    """
    # EXAMPLE 1
    max_aliases = 30
    max_seq_len = 50
    # 3114|0~*~1~*~2~*~3~*~4~*~5|mexico~*~panama~*~ecuador~*~peru~*~bolivia~*~colombia|3966054~*~22997~*~9334~*~170691~*~3462~*~5222|19:20~*~36:37~*~39:40~*~44:45~*~48:49~*~70:71|The animal is called paca in most of its range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia
    sentence = 'The animal is called paca in most of its range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia'
    aliases = [
        "mexico", "panama", "ecuador", "peru", "bolivia", "colombia"
    ]
    aliases_to_predict = [0, 1, 2, 3, 4, 5]
    spans = [[19, 20], [36, 37], [39, 40], [44, 45], [48, 49], [70, 71]]
    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # True data
    true_phrase_arr = [
        'range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media'
        .split(),
        'Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia'
        .split()
    ]
    true_spans_arr = [[[10, 11], [27, 28], [30, 31], [35, 36], [39, 40]],
                      [[15, 16], [18, 19], [23, 24], [27, 28], [49, 50]]]
    true_alias_to_predict_arr = [[0, 1, 2, 3, 4], [4]]
    true_aliases_arr = [["mexico", "panama", "ecuador", "peru", "bolivia"],
                        [
                            "panama", "ecuador", "peru", "bolivia", "colombia"
                        ]]
    assert len(idxs_arr) == 2
    assert len(aliases_to_predict_arr) == 2
    assert len(spans_arr) == 2
    assert len(phrase_tokens_arr) == 2
    for i in range(len(idxs_arr)):
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
    # EXAMPLE 2
    max_aliases = 10
    max_seq_len = 50
    # 20|0~*~1~*~2~*~3~*~4~*~5~*~6~*~7~*~8~*~9~*~10~*~11~*~12~*~13~*~14~*~15~*~16~*~17~*~18~*~19~*~20|coolock~*~swords~*~darndale~*~santry~*~donnycarney~*~baldoyle~*~sutton~*~donaghmede~*~artane~*~whitehall~*~kilbarrack~*~raheny~*~clontarf~*~fairview~*~malahide~*~howth~*~marino~*~ballybough~*~north strand~*~sheriff street~*~east wall|1037463~*~182210~*~8554720~*~2432965~*~7890942~*~1223621~*~1008011~*~3698049~*~1469895~*~2144656~*~3628425~*~1108214~*~1564212~*~1438118~*~944694~*~1037467~*~5745962~*~2436385~*~5310245~*~12170199~*~2814197|12:13~*~14:15~*~15:16~*~17:18~*~18:19~*~19:20~*~20:21~*~21:22~*~22:23~*~23:24~*~24:25~*~25:26~*~26:27~*~27:28~*~28:29~*~29:30~*~30:31~*~38:39~*~39:41~*~41:43~*~43:45|East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall
    sentence = "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall"
    aliases = [
        "coolock", "swords", "darndale", "santry", "donnycarney",
        "baldoyle", "sutton", "donaghmede", "artane", "whitehall",
        "kilbarrack", "raheny", "clontarf", "fairview", "malahide", "howth",
        "marino", "ballybough", "north strand", "sheriff street", "east wall"
    ]
    # BUGFIX: was `..., 16, 171, 8, 19, 20` — garbled form of 17, 18.
    aliases_to_predict = [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20
    ]
    spans = [[12, 13], [14, 15], [15, 16], [17, 18], [18, 19], [19, 20],
             [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26],
             [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
             [39, 41], [41, 43], [43, 45]]
    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # Truth: all three windows share the identical padded phrase (45 words + 5 pads).
    padded_phrase = (sentence + " <pad> <pad> <pad> <pad> <pad>").split()
    true_phrase_arr = [padded_phrase] * 3
    true_spans_arr = [[[12, 13], [14, 15], [15, 16], [17, 18], [18, 19],
                       [19, 20], [20, 21], [21, 22]],
                      [[20, 21], [21, 22], [22, 23], [23, 24], [24, 25],
                       [25, 26], [26, 27], [27, 28], [28, 29]],
                      [[27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
                       [39, 41], [41, 43], [43, 45]]]
    true_alias_to_predict_arr = [[0, 1, 2, 3, 4, 5, 6],
                                 [1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 6, 7]]
    true_aliases_arr = [[
        "coolock", "swords", "darndale", "santry", "donnycarney",
        "baldoyle", "sutton", "donaghmede"
    ],
                        [
                            "sutton", "donaghmede", "artane", "whitehall",
                            "kilbarrack", "raheny", "clontarf", "fairview",
                            "malahide"
                        ],
                        [
                            "fairview", "malahide", "howth", "marino",
                            "ballybough", "north strand", "sheriff street",
                            "east wall"
                        ]]
    assert len(idxs_arr) == 3
    assert len(aliases_to_predict_arr) == 3
    assert len(spans_arr) == 3
    assert len(phrase_tokens_arr) == 3
    for i in range(len(idxs_arr)):
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
    # EXAMPLE 3 (was mislabeled "Example 2")
    max_aliases = 10
    max_seq_len = 100
    # 84|0~*~1|kentucky~*~green|621151~*~478999|8:9~*~9:10|The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools
    sentence = "The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools"
    aliases = ["kentucky", "green"]
    aliases_to_predict = [0, 1]
    spans = [[8, 9], [9, 10]]
    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)
    # True data: the 28-word sentence is padded out to max_seq_len (72 pads).
    true_phrase_arr = [sentence.split() + ["<pad>"] * 72]
    true_spans_arr = [[[8, 9], [9, 10]]]
    true_alias_to_predict_arr = [[0, 1]]
    true_aliases_arr = [["kentucky", "green"]]
    assert len(idxs_arr) == 1
    assert len(aliases_to_predict_arr) == 1
    assert len(spans_arr) == 1
    assert len(phrase_tokens_arr) == 1
    for i in range(len(idxs_arr)):
        self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
        self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
        self.assertEqual(spans_arr[i], true_spans_arr[i])
        self.assertEqual(aliases_to_predict_arr[i],
                         true_alias_to_predict_arr[i])
        self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                         true_aliases_arr[i])
def test_seq_length(self):
    """Verify that ``split_sentence`` windows an over-long sentence.

    A 21-token sentence with ``max_seq_len=12`` must be broken into
    overlapping subsentence windows. For every window the returned
    parallel arrays (alias indices, aliases-to-predict, spans, phrase
    tokens) must stay aligned, with spans re-based to window-local
    token offsets. A second run with a reduced ``aliases_to_predict``
    checks that windows containing no predicted alias are dropped.
    """
    # Test maximum sequence length
    max_aliases = 30
    max_seq_len = 12

    # Load config and vocabulary once; both split_sentence calls below
    # use identical settings, so there is no need to reload from disk.
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)

    def assert_split(expected_windows, idxs_arr, aliases_to_predict_arr,
                     spans_arr, phrase_tokens_arr, true_phrase_arr,
                     true_spans_arr, true_alias_to_predict_arr,
                     true_aliases_arr):
        # Every parallel output array has one entry per window.
        assert len(idxs_arr) == expected_windows
        assert len(aliases_to_predict_arr) == expected_windows
        assert len(spans_arr) == expected_windows
        assert len(phrase_tokens_arr) == expected_windows
        for i in range(len(idxs_arr)):
            # Each window is exactly max_seq_len tokens (padded if short).
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            # idxs_arr entries index into the original aliases list.
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

    # Manual data
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big", "alias3", "alias5"]
    aliases_to_predict = [0, 1, 2]
    spans = [[0, 2], [12, 13], [20, 21]]

    # Run function
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)

    # True data: two windows are needed to cover all three aliases, and
    # spans are re-based to each window's local token positions.
    true_phrase_arr = [
        "The big alias1 ran away from dogs and multi word alias2 and".split(),
        "word alias2 and alias3 because we want our cat and our alias5".split(),
    ]
    true_spans_arr = [[[0, 2]], [[3, 4], [11, 12]]]
    true_alias_to_predict_arr = [[0], [0, 1]]
    true_aliases_arr = [["The big"], ["alias3", "alias5"]]
    assert_split(2, idxs_arr, aliases_to_predict_arr, spans_arr,
                 phrase_tokens_arr, true_phrase_arr, true_spans_arr,
                 true_alias_to_predict_arr, true_aliases_arr)

    # Now test with modified aliases to predict: removing alias 0 from
    # aliases_to_predict should remove the first window entirely.
    aliases_to_predict = [1, 2]

    # Run function (same config/vocab as above)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)

    # True data: only the second window survives.
    true_phrase_arr = [
        "word alias2 and alias3 because we want our cat and our alias5".split(),
    ]
    true_spans_arr = [[[3, 4], [11, 12]]]
    true_alias_to_predict_arr = [[0, 1]]
    true_aliases_arr = [["alias3", "alias5"]]
    assert_split(1, idxs_arr, aliases_to_predict_arr, spans_arr,
                 phrase_tokens_arr, true_phrase_arr, true_spans_arr,
                 true_alias_to_predict_arr, true_aliases_arr)