def test_shape_on_random_data(self):
    set_seed(42)

    bs = 3
    src_len = 5
    tgt_len = 7

    encoder_config = transformers.BertConfig(
        hidden_size=11,
        intermediate_size=44,
        vocab_size=17,
        num_hidden_layers=1,
        num_attention_heads=1,
    )
    encoder = transformers.BertModel(encoder_config)

    # decoder accepts vocabulary of schema vocab + pointer embeddings
    decoder_config = transformers.BertConfig(
        hidden_size=11,
        intermediate_size=44,
        vocab_size=23,
        is_decoder=True,
        num_hidden_layers=1,
        num_attention_heads=1,
    )
    decoder = transformers.BertModel(decoder_config)

    # logits are projected into schema vocab and combined with pointer scores
    max_pointer = src_len + 3
    model = EncoderDecoderWPointerModel(
        encoder=encoder, decoder=decoder, max_src_len=max_pointer)

    x_enc = torch.randint(0, encoder_config.vocab_size, size=(bs, src_len))
    x_dec = torch.randint(0, decoder_config.vocab_size, size=(bs, tgt_len))

    out = model(input_ids=x_enc, decoder_input_ids=x_dec)

    # different encoders return a different number of outputs,
    # e.g. BERT returns two, but DistilBERT returns only one
    self.assertGreaterEqual(len(out), 4)

    schema_vocab = decoder_config.vocab_size - max_pointer

    combined_logits = out[0]
    expected_shape = (bs, tgt_len, schema_vocab + src_len)
    self.assertEqual(combined_logits.shape, expected_shape)

    decoder_hidden = out[1]
    expected_shape = (bs, tgt_len, decoder_config.hidden_size)
    self.assertEqual(decoder_hidden.shape, expected_shape)

    pooled_output = out[2]
    expected_shape = (bs, decoder_config.hidden_size)
    self.assertEqual(pooled_output.shape, expected_shape)

    encoder_hidden = out[3]
    expected_shape = (bs, src_len, encoder_config.hidden_size)
    self.assertEqual(encoder_hidden.shape, expected_shape)
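# A minimal sketch of how the combined logits checked above could be produced,
# assuming the usual pointer-network combination; the real
# EncoderDecoderWPointerModel may differ in details such as masking or scaling.
# Schema-vocabulary logits are concatenated with attention scores over the
# source positions, giving shape (batch, tgt_len, schema_vocab + src_len).
def combine_logits_sketch(vocab_logits, pointer_scores, pointer_mask):
    # vocab_logits: (bs, tgt_len, schema_vocab)
    # pointer_scores: (bs, tgt_len, src_len), decoder-over-encoder attention
    # pointer_mask: (bs, src_len), 1 for source tokens that may be pointed to
    masked_scores = pointer_scores.masked_fill(pointer_mask.unsqueeze(1) == 0, -1e4)
    return torch.cat([vocab_logits, masked_scores], dim=-1)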
def setUp(self):
    utils.set_seed(3)

    src_tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
    vocab = {
        "[",
        "]",
        "IN:",
        "SL:",
        "GET_DIRECTIONS",
        "DESTINATION",
        "DATE_TIME_DEPARTURE",
        "GET_ESTIMATED_ARRIVAL",
    }
    self.schema_tokenizer = TopSchemaTokenizer(vocab, src_tokenizer)

    self.model = EncoderDecoderWPointerModel.from_parameters(
        layers=2,
        hidden=32,
        heads=2,
        src_vocab_size=src_tokenizer.vocab_size,
        tgt_vocab_size=self.schema_tokenizer.vocab_size,
        max_src_len=17,
        dropout=0.1,
    )

    source_texts = [
        "Directions to Lowell",
        "Get directions to Mountain View",
    ]
    target_texts = [
        "[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]",
        "[IN:GET_DIRECTIONS Get directions to [SL:DESTINATION Mountain View]]",
    ]

    pairs = [
        self.schema_tokenizer.encode_pair(t, s)
        for t, s in zip(target_texts, source_texts)
    ]
    self.dataset = PointerDataset.from_pair_items(pairs)
    self.dataset.torchify()

    collator = Seq2SeqDataCollator(pad_id=self.schema_tokenizer.pad_token_id)
    dataloader = torch.utils.data.DataLoader(
        self.dataset, batch_size=2, collate_fn=collator.collate_batch)
    self.test_batch = next(iter(dataloader))

    self.module = PointerModule(
        model=self.model,
        schema_tokenizer=self.schema_tokenizer,
        train_dataset=self.dataset,
        valid_dataset=self.dataset,
        lr=1e-3,
    )
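# Rough illustration of the encoding assumed in setUp (not the exact output of
# TopSchemaTokenizer.encode_pair): structural tokens such as "[", "IN:",
# "GET_DIRECTIONS", "SL:", "DESTINATION" and "]" come from the schema
# vocabulary, while source words such as "Directions", "to" and "Lowell" are
# represented as pointers to their positions in the source sentence, e.g.
#
#   "[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]"
#   -> [  IN:  GET_DIRECTIONS  @ptr_0  @ptr_1  [  SL:  DESTINATION  @ptr_2  ]  ]
#
# (pointer indices are illustrative). This is why the decoder vocabulary is
# the schema vocabulary plus max_src_len pointer embeddings.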
def test_smoothing(self):
    # only checks that it does not fail
    set_seed(98)

    v_size = 43
    eps = 0.1

    preds = torch.randn(size=(7, 19, v_size)).view(-1, v_size)
    labels = torch.randint(v_size, size=(7, 19)).view(-1)
    mask = torch.ones_like(labels)

    LabelSmoothedCrossEntropy(eps=eps)(preds, labels, mask)
def test_no_smoothing(self):
    set_seed(98)

    v_size = 43

    preds = torch.randn(size=(7, 19, v_size)).view(-1, v_size)
    labels = torch.randint(v_size, size=(7, 19)).view(-1)
    mask = torch.ones_like(labels)

    ce1 = F.cross_entropy(preds, labels)
    ce2 = LabelSmoothedCrossEntropy(eps=0)(preds, labels, mask)

    self.assertTrue(torch.allclose(ce1, ce2))
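# A reference sketch of the label smoothing formulation that the two tests
# above appear to assume: with eps=0 the loss must reduce to plain cross
# entropy. The actual LabelSmoothedCrossEntropy in this repo may handle the
# mask and reduction differently; this is only an illustration.
def label_smoothed_nll_sketch(preds, labels, eps):
    # preds: (N, vocab_size) unnormalized logits, labels: (N,) class indices
    log_probs = F.log_softmax(preds, dim=-1)
    nll = -log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    smooth = -log_probs.mean(dim=-1)  # cross entropy against a uniform distribution
    return ((1.0 - eps) * nll + eps * smooth).mean()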
def test_shape_on_real_data_batched(self):
    set_seed(42)

    src_vocab_size = 17
    tgt_vocab_size = 23
    max_position = 7

    encoder_config = transformers.BertConfig(
        hidden_size=11,
        intermediate_size=44,
        vocab_size=src_vocab_size,
        num_hidden_layers=1,
        num_attention_heads=1,
    )
    encoder = transformers.BertModel(encoder_config)

    decoder_config = transformers.BertConfig(
        hidden_size=11,
        intermediate_size=44,
        vocab_size=tgt_vocab_size + max_position,
        is_decoder=True,
        num_hidden_layers=1,
        num_attention_heads=1,
    )
    decoder = transformers.BertModel(decoder_config)

    model = EncoderDecoderWPointerModel(
        encoder=encoder, decoder=decoder, max_src_len=max_position)

    # similar to real data
    src_seq = torch.LongTensor([[1, 6, 12, 15, 2, 0, 0],
                                [1, 6, 12, 15, 5, 3, 2]])
    tgt_seq = torch.LongTensor([
        [8, 6, 4, 10, 11, 8, 5, 1, 12, 7, 7, 0, 0],
        [8, 6, 4, 10, 11, 8, 5, 1, 12, 13, 14, 7, 7],
    ])
    mask = torch.FloatTensor([[0, 1, 1, 1, 0, 0, 0],
                              [0, 1, 1, 1, 1, 1, 0]])

    combined_logits = model(
        input_ids=src_seq, decoder_input_ids=tgt_seq, pointer_mask=mask)[0]

    expected_shape = (2, tgt_seq.shape[1], tgt_vocab_size + src_seq.shape[1])
    self.assertEqual(combined_logits.shape, expected_shape)
def test_getitem(self):
    utils.set_seed(29)

    d_mask = 2

    src_tensors = [
        torch.tensor([2, 5, 4, 4, 2]),
        torch.tensor([1, 8, 2, 8, 5, 4, 2, 2, 5, 7]),
        torch.tensor([4, 6, 4, 2, 1, 2]),
    ]
    tgt_tensors = [
        torch.tensor([6, 7, 8, 7, 2, 2, 4, 8, 5]),
        torch.tensor([5, 2, 2, 8, 7, 3, 5, 4, 2, 2, 1]),
        torch.tensor([8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7, 8]),
    ]
    tgt_masks = [(tgt_tensors[i] == d_mask).type(torch.FloatTensor) for i in range(3)]

    expected_decoder_input_ids = [
        torch.tensor([6, 7, 8, 7, 2, 2, 4, 8]),
        torch.tensor([5, 2, 2, 8, 7, 3, 5, 4, 2, 2]),
        torch.tensor([8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7]),
    ]
    expected_decoder_pointer_mask = [
        torch.tensor([0, 0, 0, 0, 1, 1, 0, 0]),
        torch.tensor([0, 1, 1, 0, 0, 0, 0, 0, 1, 1]),
        torch.tensor([0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0]),
    ]

    dataset = nsp.data.PointerDataset(src_tensors, tgt_tensors, target_pointer_masks=tgt_masks)

    item: InputDataClass = dataset[0]

    self.assertIsInstance(item, InputDataClass)
    self.assertIsInstance(item.input_ids, torch.LongTensor)
    self.assertIsInstance(item.decoder_input_ids, torch.LongTensor)
    self.assertIsInstance(item.labels, torch.LongTensor)

    self.assertTrue(torch.all(item.input_ids == src_tensors[0]))
    self.assertTrue(torch.all(item.decoder_input_ids == expected_decoder_input_ids[0]))
    self.assertTrue(torch.all(item.decoder_pointer_mask == expected_decoder_pointer_mask[0]))
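# Note: the expectations above imply that PointerDataset returns the target
# sequence without its last token as decoder_input_ids (teacher forcing);
# presumably labels are the correspondingly shifted targets, although this
# test only checks their type.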
def test_shape_on_real_data(self):
    set_seed(42)

    src_vocab_size = 17
    tgt_vocab_size = 23
    max_position = 5

    encoder_config = transformers.BertConfig(
        hidden_size=11,
        intermediate_size=44,
        vocab_size=src_vocab_size,
        num_hidden_layers=1,
        num_attention_heads=1,
    )
    encoder = transformers.BertModel(encoder_config)

    decoder_config = transformers.BertConfig(
        hidden_size=11,
        intermediate_size=44,
        vocab_size=tgt_vocab_size + max_position,
        is_decoder=True,
        num_hidden_layers=1,
        num_attention_heads=1,
    )
    decoder = transformers.BertModel(decoder_config)

    model = EncoderDecoderWPointerModel(
        encoder=encoder, decoder=decoder, max_src_len=max_position)

    # similar to real data
    # e.g. '[CLS] Directions to Lowell [SEP]'
    src_seq = torch.LongTensor([[1, 6, 12, 15, 2]])
    # e.g. '[IN:GET_DIRECTIONS Directions to [SL:DESTINATION Lowell]]'
    tgt_seq = torch.LongTensor([[8, 6, 4, 10, 11, 8, 5, 1, 12, 7, 7]])
    mask = torch.FloatTensor([[0, 1, 1, 1, 0]])

    combined_logits = model(
        input_ids=src_seq, decoder_input_ids=tgt_seq, pointer_mask=mask)[0]

    expected_shape = (1, tgt_seq.shape[1], tgt_vocab_size + src_seq.shape[1])
    self.assertEqual(combined_logits.shape, expected_shape)
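# The pointer_mask above appears to mark which encoder positions may be
# pointed to: 1 for content tokens ("Directions", "to", "Lowell") and 0 for
# special tokens such as [CLS]/[SEP] (and for padding in the batched test).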
def test_getitem_inference(self):
    utils.set_seed(29)

    src_tensors = [
        torch.randint(0, 10, size=(random.randint(5, 13),), dtype=torch.int64)
        for _ in range(10)
    ]

    dataset = nsp.data.PointerDataset(src_tensors)

    item = dataset[0]

    self.assertIsInstance(item, InputDataClass)
    self.assertIsInstance(item.input_ids, torch.LongTensor)
    self.assertIsNone(item.decoder_input_ids)
    self.assertIsNone(item.labels)
def main(args):
    utils.set_seed(args.seed)

    if os.path.exists(args.output_dir):
        raise ValueError(f"output_dir {args.output_dir} already exists")

    wandb_logger = pl.loggers.WandbLogger(project=args.wandb_project, tags=args.tags)
    wandb_logger.log_hyperparams(args)

    logger.info(f"Starting finetuning with args: \n{pprint.pformat(vars(args))}")

    logger.info("Loading tokenizers")
    schema_tokenizer = load_tokenizer(args.model_dir, args.data_dir)

    logger.info("Loading data")
    train_dataset, eval_dataset = load_data(
        path=path_join(args.data_dir, "data.pkl"),
        new_data_amount=args.new_data_amount,
        old_data_amount=args.old_data_amount,
        old_data_sampling_method=args.old_data_sampling_method,
        wandb_logger=wandb_logger,
    )

    train_args = cli_utils.load_saved_args(path_join(args.model_dir, "args.toml"))

    # NOTE: do not log metrics as hyperparameters
    wandb_logger.log_hyperparams(
        {"pretrain_" + k: v for k, v in train_args.items() if k != "metrics"})
    wandb_logger.log_hyperparams({"num_total_data": len(train_dataset)})

    logger.info("Loading model")
    model = load_model(
        model_dir=args.model_dir,
        dropout=args.dropout,
        move_norm=args.move_norm,
        move_norm_p=args.move_norm_p,
        label_smoothing=args.label_smoothing,
        weight_consolidation=args.weight_consolidation,
    )

    logger.info("Preparing for training")
    lightning_module = load_lightning_module(
        checkpoint_path=train_args["pl_checkpoint_path"],
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        schema_tokenizer=schema_tokenizer,
        args=args,
        wandb_logger=wandb_logger,
    )

    # override some of the parameters saved in the Trainer
    checkpoint_path = modify_checkpoint_for_retraining(
        train_args["pl_checkpoint_path"],
        args.output_dir,
        args.lr,
        args.weight_decay,
        args.no_opt_state,
        lightning_module,
    )
    trainer = load_trainer(checkpoint_path, args, wandb_logger)

    # get evaluation metrics of the initial model
    pretrain_metrics = train_args["metrics"]
    _first_step_metrics = {
        "epoch": -1,
        "global_step": -1,
        **pretrain_metrics["means"],
        **pretrain_metrics["stdevs"],
    }
    wandb_logger.log_metrics(_first_step_metrics, step=-1)

    wandb_logger.watch(lightning_module, log="all", log_freq=lightning_module.log_every)

    # --- FIT

    # call .test to load the optimizer state
    with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
        trainer.test(lightning_module, lightning_module.val_dataloader(subset_size=0.01))

    cli_utils.check_config(lightning_module, trainer, args, strict=True)

    if model.config.move_norm is not None:
        assert torch.allclose(model.get_move_norm(), torch.zeros(1, device=model.device))

    trainer.fit(lightning_module)

    cli_utils.check_config(lightning_module, trainer, args, strict=True)

    with open(path_join(args.output_dir, "args.toml"), "w") as f:
        args_dict = {"version": nsp.SAVE_FORMAT_VERSION, **vars(args)}
        toml.dump(args_dict, f)

    logger.info("Training finished!")

    best_model_checkpoint = trainer.checkpoint_callback.last_checkpoint_path
    if args.average_checkpoints:
        average_checkpoints(model, args, save_to=os.path.dirname(best_model_checkpoint))

    final_metrics, description = cli_utils.evaluate_model(
        best_model_checkpoint,
        schema_tokenizer,
        eval_dataset,
        prefix="eval",
        max_len=train_args.get("max_tgt_len", 68),  # 68 is max_tgt_len for TOP
    )

    logger.info(description)
    wandb_logger.log_metrics({**final_metrics["means"], **final_metrics["stdevs"]})

    # Compute RI and RD
    class_weights = eval_dataset.get_class_frequencies(schema_tokenizer)
    class_weights = {f"cls/eval_{cls}_tree_path_f1": p for cls, p in class_weights.items()}

    finetuning_metrics = cli_utils.evaluate_finetuning_procedure(
        pretrain_metrics, final_metrics, class_weights)
    wandb_logger.log_metrics(finetuning_metrics)

    # Compute RI and RD again, handling very small outliers (sigma=0.1)
    finetuning_metrics0 = cli_utils.evaluate_finetuning_procedure(
        pretrain_metrics, final_metrics, class_weights, sigma=0.1)
    finetuning_metrics0 = {k + "_0.1": v for k, v in finetuning_metrics0.items()}
    wandb_logger.log_metrics(finetuning_metrics0)

    wandb_logger.close()

    if args.clean_output:
        shutil.rmtree(args.output_dir)
def main(args):
    utils.set_seed(args.seed)

    if os.path.exists(args.output_dir):
        raise ValueError(f"output_dir {args.output_dir} already exists")

    # File structure:
    # that's text\tthat 's text\t[IN:UNSUPPORTED that 's text]
    train_path = path_join(args.data, "train.tsv")
    train_data = pd.read_table(train_path, names=["text", "tokens", "schema"])
    full_train_data_size = len(train_data)  # used to check the train/finetune split

    finetune_data, finetune_path = None, None
    schema_vocab = reduce(set.union, map(utils.get_vocab_top_schema, train_data.schema))

    if args.split_amount is not None:
        # the finetune part is not used by the train script, but is used by the retrain script
        logger.info("Splitting the training dataset")
        train_data, finetune_data = train_finetune_split(
            train_data, schema_vocab, args.split_amount, args.split_class)

        os.makedirs(args.output_dir)

        finetune_path = path_join(args.output_dir, "finetune.tsv")
        logger.info(f"Saving the finetune data to {finetune_path}")
        finetune_data.to_csv(finetune_path, sep="\t", index=False, header=False)

        train_path = path_join(args.output_dir, "train.tsv")
        logger.info(f"Saving the modified training set to {train_path}")
        train_data.to_csv(train_path, sep="\t", index=False, header=False)

    logger.info("Getting schema vocabulary")

    if args.split_amount is not None:
        finetune_schema_vocab = reduce(
            set.union, map(utils.get_vocab_top_schema, finetune_data.schema))
        vocab_delta = finetune_schema_vocab - schema_vocab
        if len(vocab_delta) > 0:
            logger.warning(
                "Finetuning subset contains vocabulary elements not present in the training subset")
            logger.warning(f"New elements: {', '.join(vocab_delta)}")

    logger.info(f"Schema vocabulary size: {len(schema_vocab)}")

    logger.info("Building tokenizers")
    text_tokenizer = transformers.AutoTokenizer.from_pretrained(args.text_tokenizer, use_fast=True)
    schema_tokenizer = nsp.TopSchemaTokenizer(schema_vocab, text_tokenizer)

    logger.info("Tokenizing train dataset")
    train_dataset = nsp.data.make_dataset(train_path, schema_tokenizer)

    logger.info("Tokenizing validation and test datasets")
    valid_dataset = nsp.data.make_dataset(path_join(args.data, "eval.tsv"), schema_tokenizer)
    test_dataset = nsp.data.make_dataset(path_join(args.data, "test.tsv"), schema_tokenizer)

    finetune_dataset = None
    if args.split_amount is not None:
        logger.info("Tokenizing finetune set")
        finetune_dataset = nsp.data.make_dataset(finetune_path, schema_tokenizer)

        logger.info(f"Original train set size: {full_train_data_size}")
        logger.info(f"Reduced train set size: {len(train_dataset)}")
        logger.info(f"Finetune set size: {len(finetune_dataset)}")

        train_finetune_data_size = len(train_dataset) + len(finetune_dataset)
        if train_finetune_data_size != full_train_data_size:
            raise RuntimeError(f"{train_finetune_data_size} != {full_train_data_size}")

    logger.info(f"Saving config, data and tokenizer to {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    with open(path_join(args.output_dir, "args.toml"), "w") as f:
        args_dict = {"version": nsp.SAVE_FORMAT_VERSION, **vars(args)}
        toml.dump(args_dict, f)

    # the text tokenizer is saved along with the schema tokenizer
    model_type = None
    if not os.path.exists(args.text_tokenizer):
        model_type = utils.get_model_type(args.text_tokenizer)

    schema_tokenizer.save(path_join(args.output_dir, "tokenizer"), encoder_model_type=model_type)

    data_state = {
        "train_dataset": train_dataset,
        "valid_dataset": valid_dataset,
        "test_dataset": test_dataset,
        "finetune_dataset": finetune_dataset,
        "version": nsp.SAVE_FORMAT_VERSION,
    }
    torch.save(data_state, path_join(args.output_dir, "data.pkl"))
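# A rough sketch of the kind of schema-vocabulary extraction assumed above:
# collect the structural tokens ("[", "]", "IN:", "SL:") together with intent
# and slot names from a TOP-format schema string. The real
# utils.get_vocab_top_schema may behave differently; this function and its
# name are illustrative only.
import re


def get_vocab_top_schema_sketch(schema: str) -> set:
    vocab = set()
    for match in re.finditer(r"\[(IN|SL):([A-Z_]+)", schema):
        vocab.update({"[", "]", match.group(1) + ":", match.group(2)})
    return vocab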
def test_collate_batch_shapes(self):
    utils.set_seed(29)

    bs = 3
    e_pad = 0
    d_pad = 1
    d_mask = 2

    src_tensors = [
        torch.tensor([2, 5, 4, 4, 2]),
        torch.tensor([1, 8, 2, 8, 5, 4, 2, 2, 5, 7]),
        torch.tensor([4, 6, 4, 2, 1, 2]),
    ]
    tgt_tensors = [
        torch.tensor([6, 7, 8, 7, 2, 2, 4, 8, 5]),
        torch.tensor([5, 2, 2, 8, 7, 3, 5, 4, 2, 2, 1]),
        torch.tensor([8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7, 8]),
    ]
    tgt_masks = [(tgt_tensors[i] == d_mask).type(torch.FloatTensor) for i in range(3)]

    expected_input_ids = torch.tensor([
        [2, 5, 4, 4, 2, 0, 0, 0, 0, 0],
        [1, 8, 2, 8, 5, 4, 2, 2, 5, 7],
        [4, 6, 4, 2, 1, 2, 0, 0, 0, 0],
    ])
    expected_decoder_input_ids = torch.tensor([
        [6, 7, 8, 7, 2, 2, 4, 8, 5, 1, 1, 1, 1, 1],
        [5, 2, 2, 8, 7, 3, 5, 4, 2, 2, 1, 1, 1, 1],
        [8, 2, 2, 3, 5, 2, 2, 2, 3, 8, 4, 6, 7, 8],
    ])
    expected_decoder_pointer_mask = torch.tensor([
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    ])

    assert d_mask in tgt_tensors[0]

    examples = [
        InputDataClass(
            input_ids=src_tensors[i],
            decoder_input_ids=tgt_tensors[i],
            decoder_pointer_mask=tgt_masks[i],
            labels=tgt_tensors[i],
        )
        for i in range(bs)
    ]

    collator = nsp.data.Seq2SeqDataCollator(e_pad, d_pad)
    batch = collator.collate_batch(examples)

    self.assertEqual(batch["input_ids"].shape, (bs, 10))
    self.assertIsInstance(batch["input_ids"], torch.LongTensor)
    self.assertEqual(batch["input_ids"][0, -1], e_pad)
    self.assertTrue(torch.all(batch["input_ids"] == expected_input_ids))

    self.assertEqual(batch["decoder_input_ids"].shape, (bs, 14))
    self.assertIsInstance(batch["decoder_input_ids"], torch.LongTensor)
    self.assertEqual(batch["decoder_input_ids"][0, -1], d_pad)
    self.assertTrue(torch.all(batch["decoder_input_ids"] == expected_decoder_input_ids))

    self.assertEqual(batch["labels"].shape, (bs, 14))
    self.assertIsInstance(batch["labels"], torch.LongTensor)

    self.assertEqual(batch["decoder_pointer_mask"].shape, (bs, 14))
    self.assertIsInstance(batch["decoder_pointer_mask"], torch.FloatTensor)
    _mask = batch["decoder_pointer_mask"]
    self.assertTrue(((_mask == 0) | (_mask == 1)).all())
    self.assertTrue(torch.all(batch["decoder_pointer_mask"] == expected_decoder_pointer_mask))
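# A minimal padding sketch consistent with the expectations above; the real
# Seq2SeqDataCollator may build the batch differently, and the padding value
# used for labels is an assumption (the test only checks their shape and type).
def collate_batch_sketch(examples, e_pad, d_pad):
    from torch.nn.utils.rnn import pad_sequence

    return {
        # source sequences padded with the encoder pad id
        "input_ids": pad_sequence(
            [e.input_ids for e in examples], batch_first=True, padding_value=e_pad),
        # target sequences padded with the decoder pad id
        "decoder_input_ids": pad_sequence(
            [e.decoder_input_ids for e in examples], batch_first=True, padding_value=d_pad),
        "labels": pad_sequence(
            [e.labels for e in examples], batch_first=True, padding_value=d_pad),
        # padded positions can never be pointed to
        "decoder_pointer_mask": pad_sequence(
            [e.decoder_pointer_mask for e in examples], batch_first=True, padding_value=0.0),
    }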
def main(args):
    utils.set_seed(args.seed)

    wandb_logger = pl.loggers.WandbLogger(project=args.wandb_project, tags=args.tags)
    wandb_logger.log_hyperparams(args)

    logger.info(f"Starting training with args: \n{pprint.pformat(vars(args))}")

    if os.path.exists(args.output_dir):
        raise ValueError(f"output_dir {args.output_dir} already exists")

    logger.info("Loading tokenizers")
    schema_tokenizer = nsp.TopSchemaTokenizer.load(path_join(args.data_dir, "tokenizer"))

    logger.info("Loading data")
    datasets = torch.load(path_join(args.data_dir, "data.pkl"))
    train_dataset: nsp.PointerDataset = datasets["train_dataset"]
    eval_dataset: nsp.PointerDataset = datasets["valid_dataset"]

    wandb_logger.log_hyperparams({"num_data": len(train_dataset)})

    max_src_len, max_tgt_len = train_dataset.get_max_len()

    try:
        preprocess_args = cli_utils.load_saved_args(path_join(args.data_dir, "args.toml"))
        wandb.config.update({"preprocess_" + k: v for k, v in preprocess_args.items()})
    except FileNotFoundError:
        preprocess_args = None

    logger.info("Creating a model")
    model = make_model(schema_tokenizer, max_src_len, args, preprocess_args)

    logger.info("Preparing for training")
    lightning_module = make_lightning_module(
        model, schema_tokenizer, train_dataset, eval_dataset, max_tgt_len, args, wandb_logger)

    trainer = make_trainer(args, wandb_logger)

    # --- FIT

    cli_utils.check_config(lightning_module, trainer, args)

    trainer.fit(lightning_module)

    if args.track_grad_square:
        lightning_module.model.register_weight_consolidation_buffer()

    logger.info("Training finished!")

    # save_top_k == 1 --> the last checkpoint is the best model
    assert trainer.checkpoint_callback.save_top_k == 1
    logger.info("Loading and evaluating the best model")
    final_metrics, description = cli_utils.evaluate_model(
        trainer.checkpoint_callback.last_checkpoint_path,
        schema_tokenizer,
        eval_dataset,
        prefix="eval",
        max_len=max_tgt_len,
    )

    with open(path_join(args.output_dir, "args.toml"), "w") as f:
        args_dict = {
            "version": nsp.SAVE_FORMAT_VERSION,
            "pl_checkpoint_path": trainer.checkpoint_callback.last_checkpoint_path,
            "metrics": final_metrics,
            "max_src_len": max_src_len,
            "max_tgt_len": max_tgt_len,
            **vars(args),
        }
        toml.dump(args_dict, f)

    logger.info(description)
    wandb_logger.log_metrics({**final_metrics["means"], **final_metrics["stdevs"]})

    wandb_logger.close()