def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    check_output_dir(training_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.parallel_mode == ParallelMode.DISTRIBUTED),
        training_args.fp16,
    )
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout",
                          "attention_dropout")
    for p in extra_model_params:
        if getattr(training_args, p, None):
            assert hasattr(
                config, p
            ), f"({config.__class__.__name__}) doesn't have a `{p}` attribute"
            setattr(config, p, getattr(training_args, p))

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # add custom tokens
    # special_tokens_list = ["<|TOPIC|>", "<|ARGUMENT|>", "<|ASPECTS|>", "<|CONCLUSION|>"]
    special_tokens_list = ["<|THESIS|>", "<|TITLE|>", "<|SUMMARY|>"]
    special_tokens_dict = {'additional_special_tokens': special_tokens_list}
    num_added_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    logger.info("Added {} new tokens".format(num_added_tokens))
    logger.info(tokenizer.special_tokens_map)

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
    )
    # resize the model embeddings to include the new tokens
    model.resize_token_embeddings(len(tokenizer))

    # use task specific params
    use_task_specific_params(model, data_args.task)

    # set num_beams for evaluation
    if data_args.eval_beams is None:
        data_args.eval_beams = model.config.num_beams

    # set decoder_start_token_id for MBart
    if model.config.decoder_start_token_id is None and isinstance(
            tokenizer, (MBartTokenizer, MBartTokenizerFast)):
        assert (data_args.tgt_lang is not None and data_args.src_lang
                is not None), "mBart requires --tgt_lang and --src_lang"
        if isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[
                data_args.tgt_lang]
        else:
            model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(
                data_args.tgt_lang)

    if model_args.freeze_embeds:
        freeze_embeds(model)
    if model_args.freeze_encoder:
        freeze_params(model.get_encoder())
        assert_all_frozen(model.get_encoder())

    dataset_class = Seq2SeqDataset

    # Get datasets
    train_dataset = (dataset_class(
        tokenizer,
        type_path="train",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_train,
        max_target_length=data_args.max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_train else None)
    eval_dataset = (dataset_class(
        tokenizer,
        type_path="val",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_val,
        max_target_length=data_args.val_max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_eval
                    or training_args.evaluation_strategy != EvaluationStrategy.NO else None)
    test_dataset = (dataset_class(
        tokenizer,
        type_path="test",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_test,
        max_target_length=data_args.test_max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_predict else None)

    # Initialize our Trainer
    compute_metrics_fn = (build_compute_metrics_fn(data_args.task, tokenizer)
                          if training_args.predict_with_generate else None)
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_args=data_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=Seq2SeqDataCollator(tokenizer, data_args,
                                          model.config.decoder_start_token_id,
                                          training_args.tpu_num_cores),
        compute_metrics=compute_metrics_fn,
        tokenizer=tokenizer,
    )

    all_metrics = {}
    # Training
    if training_args.do_train:
        logger.info("*** Train ***")
        train_result = trainer.train(
            model_path=model_args.model_name_or_path
            if os.path.isdir(model_args.model_name_or_path) else None)
        metrics = train_result.metrics
        metrics["train_n_objs"] = data_args.n_train

        trainer.save_model()  # this also saves the tokenizer

        if trainer.is_world_process_zero():
            handle_metrics("train", metrics, training_args.output_dir)
            all_metrics.update(metrics)

            # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
            trainer.state.save_to_json(
                os.path.join(training_args.output_dir, "trainer_state.json"))

            # For convenience, we also re-save the tokenizer to the same directory,
            # so that you can share your model easily on huggingface.co/models =)
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="val")
        metrics["val_n_objs"] = data_args.n_val
        metrics["val_loss"] = round(metrics["val_loss"], 4)

        if trainer.is_world_process_zero():
            handle_metrics("val", metrics, training_args.output_dir)
            all_metrics.update(metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        test_output = trainer.predict(test_dataset=test_dataset,
                                      metric_key_prefix="test")
        metrics = test_output.metrics
        metrics["test_n_objs"] = data_args.n_test

        if trainer.is_world_process_zero():
            metrics["test_loss"] = round(metrics["test_loss"], 4)
            handle_metrics("test", metrics, training_args.output_dir)
            all_metrics.update(metrics)

            if training_args.predict_with_generate:
                test_preds = tokenizer.batch_decode(
                    test_output.predictions,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True)
                test_preds = lmap(str.strip, test_preds)
                write_txt_file(
                    test_preds,
                    os.path.join(training_args.output_dir, "test_generations.txt"))

    if trainer.is_world_process_zero():
        save_json(all_metrics,
                  os.path.join(training_args.output_dir, "all_results.json"))

    return all_metrics
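# The main() above follows the layout of the Hugging Face legacy seq2seq
# fine-tuning script. Example invocations are sketched below; the script name,
# checkpoint, and paths are illustrative assumptions, not taken from the source:
#
#   python finetune_trainer.py \
#       --model_name_or_path facebook/bart-base \
#       --data_dir ./data \
#       --output_dir ./output \
#       --task summarization \
#       --do_train --do_eval --predict_with_generate
#
#   # or, equivalently, pass a single JSON file holding the same arguments:
#   python finetune_trainer.py train_config.json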
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout",
                          "attention_dropout")
    for p in extra_model_params:
        if getattr(training_args, p, None):
            assert hasattr(
                config, p
            ), f"({config.__class__.__name__}) doesn't have a `{p}` attribute"
            setattr(config, p, getattr(training_args, p))

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # use task specific params
    use_task_specific_params(model, data_args.task)

    # set num_beams for evaluation
    if data_args.eval_beams is None:
        data_args.eval_beams = model.config.num_beams

    # set decoder_start_token_id for MBart
    if model.config.decoder_start_token_id is None and isinstance(
            tokenizer, MBartTokenizer):
        assert (data_args.tgt_lang is not None and data_args.src_lang
                is not None), "mBart requires --tgt_lang and --src_lang"
        model.config.decoder_start_token_id = tokenizer.lang_code_to_id[
            data_args.tgt_lang]

    if model_args.freeze_embeds:
        freeze_embeds(model)
    if model_args.freeze_encoder:
        freeze_params(model.get_encoder())
        assert_all_frozen(model.get_encoder())

    dataset_class = Seq2SeqDataset if hasattr(
        tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset

    # Get datasets
    train_dataset = (dataset_class(
        tokenizer,
        type_path="train",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_train,
        max_target_length=data_args.max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_train else None)
    eval_dataset = (dataset_class(
        tokenizer,
        type_path="val",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_val,
        max_target_length=data_args.val_max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_eval
                    or training_args.evaluation_strategy != EvaluationStrategy.NO else None)
    test_dataset = (dataset_class(
        tokenizer,
        type_path="test",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_test,
        max_target_length=data_args.test_max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_predict else None)

    # Initialize our Trainer
    compute_metrics_fn = (build_compute_metrics_fn(data_args.task, tokenizer)
                          if training_args.predict_with_generate else None)
    trainer = Seq2SeqTrainer(
        model=model,
        config=config,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=Seq2SeqDataCollator(tokenizer, data_args,
                                          training_args.tpu_num_cores),
        compute_metrics=compute_metrics_fn,
        data_args=data_args,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path
                      if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_process_zero():
            trainer.state.save_to_json(
                os.path.join(training_args.output_dir, "trainer_state.json"))
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        result = trainer.evaluate()

        if trainer.is_world_process_zero():
            logger.info("***** Eval results *****")
            for key, value in result.items():
                logger.info("  %s = %s", key, value)
            save_json(result,
                      os.path.join(training_args.output_dir, "eval_results.json"))
            eval_results.update(result)

    if training_args.do_predict:
        logging.info("*** Test ***")
        test_output = trainer.predict(test_dataset=test_dataset)
        test_metrics = {
            k.replace("eval", "test"): v
            for k, v in test_output.metrics.items()
        }

        if trainer.is_world_process_zero():
            logger.info("***** Test results *****")
            for key, value in test_metrics.items():
                logger.info("  %s = %s", key, value)
            save_json(test_metrics,
                      os.path.join(training_args.output_dir, "test_results.json"))
            eval_results.update(test_metrics)

            if training_args.predict_with_generate:
                test_preds = tokenizer.batch_decode(
                    test_output.predictions,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True)
                test_preds = lmap(str.strip, test_preds)
                write_txt_file(
                    test_preds,
                    os.path.join(training_args.output_dir, "test_generations.txt"))

    if trainer.is_world_process_zero():
        save_json(eval_results, "all_results.json")

    return eval_results
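# `save_json`, `write_txt_file` and `lmap` are small helper utilities used by
# the scripts above but not shown here. A minimal sketch, under the assumption
# that they behave as their names suggest:
import json
from typing import Callable, Iterable, List


def save_json(content, path: str, indent: int = 4):
    """Serialize `content` to `path` as JSON."""
    with open(path, "w") as f:
        json.dump(content, f, indent=indent)


def write_txt_file(lines: Iterable[str], path: str):
    """Write one string per line to `path`."""
    with open(path, "w") as f:
        f.write("\n".join(lines))


def lmap(f: Callable, x: Iterable) -> List:
    """Shorthand for list(map(f, x))."""
    return list(map(f, x))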
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=".ckpt" in model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # use task specific params
    use_task_specific_params(model, data_args.task)

    # set num_beams for evaluation
    if data_args.eval_beams is not None:
        model.config.num_beams = data_args.eval_beams
    assert model.config.num_beams >= 1, f"got eval_beams={model.config.num_beams}. Need an integer >= 1"

    # set max length for generation
    model.config.max_generate_length = data_args.val_max_target_length

    # set decoder_start_token_id for MBart
    if model.config.decoder_start_token_id is None and isinstance(
            tokenizer, MBartTokenizer):
        decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
        model.config.decoder_start_token_id = decoder_start_token_id

    def build_compute_metrics_fn(
            task_name: str) -> Callable[[EvalPrediction], Dict]:
        def non_pad_len(tokens: np.ndarray) -> int:
            return np.count_nonzero(tokens != tokenizer.pad_token_id)

        def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
            pred_str = tokenizer.batch_decode(pred.predictions,
                                              skip_special_tokens=True)
            label_str = tokenizer.batch_decode(pred.label_ids,
                                               skip_special_tokens=True)
            pred_str = lmap(str.strip, pred_str)
            label_str = lmap(str.strip, label_str)
            return pred_str, label_str

        def summarization_metrics(pred: EvalPrediction) -> Dict:
            pred_str, label_str = decode_pred(pred)
            rouge: Dict = calculate_rouge(pred_str, label_str)
            summ_len = np.mean(lmap(non_pad_len, pred.predictions))
            rouge.update({"gen_len": summ_len})
            return rouge

        def translation_metrics(pred: EvalPrediction) -> Dict:
            pred_str, label_str = decode_pred(pred)
            bleu: Dict = calculate_bleu(pred_str, label_str)
            gen_len = np.mean(lmap(non_pad_len, pred.predictions))
            bleu.update({"gen_len": gen_len})
            return bleu

        compute_metrics_fn = (summarization_metrics if "summarization"
                              in task_name else translation_metrics)
        return compute_metrics_fn

    def freeze_embeds(model: torch.nn.Module):
        """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
        try:
            freeze_params(model.model.shared)
            for d in [model.model.encoder, model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            freeze_params(model.shared)
            for d in [model.encoder, model.decoder]:
                freeze_params(d.embed_tokens)

    if model_args.freeze_embeds:
        freeze_embeds(model)
    if model_args.freeze_encoder:
        freeze_params(model.get_encoder())
        assert_all_frozen(model.get_encoder())

    dataset_class = Seq2SeqDataset if hasattr(
        tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset

    # Get datasets
    train_dataset = (dataset_class(
        tokenizer,
        type_path="train",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_train,
        max_target_length=data_args.max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_train else None)
    eval_dataset = (dataset_class(
        tokenizer,
        type_path="val",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_val,
        max_target_length=data_args.val_max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_eval
                    or training_args.evaluation_strategy != EvaluationStrategy.NO else None)
    test_dataset = (dataset_class(
        tokenizer,
        type_path="test",
        data_dir=data_args.data_dir,
        n_obs=data_args.n_test,
        max_target_length=data_args.test_max_target_length,
        max_source_length=data_args.max_source_length,
        prefix=model.config.prefix or "",
    ) if training_args.do_predict else None)

    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=Seq2SeqDataCollator(tokenizer, data_args,
                                          training_args.tpu_num_cores),
        compute_metrics=build_compute_metrics_fn(data_args.task)
        if training_args.predict_with_generate else None,
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path
                      if os.path.isdir(model_args.model_name_or_path) else None)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        result = trainer.evaluate()
        output_eval_file = os.path.join(training_args.output_dir,
                                        "eval_results.json")

        if trainer.is_world_process_zero():
            logger.info("***** Eval results *****")
            for key, value in result.items():
                logger.info("  %s = %s", key, value)
            with open(output_eval_file, "w") as f:
                json.dump(result, f)
            eval_results.update(result)

    if training_args.do_predict:
        logging.info("*** Test ***")
        test_output = trainer.predict(test_dataset=test_dataset)
        test_metrics = test_output.metrics
        test_metrics = {
            k.replace("eval", "test"): v
            for k, v in test_metrics.items()
        }
        output_test_file = os.path.join(training_args.output_dir,
                                        "test_results.json")

        if trainer.is_world_process_zero():
            logger.info("***** Test results *****")
            for key, value in test_metrics.items():
                logger.info("  %s = %s", key, value)
            with open(output_test_file, "w") as f:
                json.dump(test_metrics, f)

            if training_args.predict_with_generate:
                test_preds = tokenizer.batch_decode(test_output.predictions,
                                                    skip_special_tokens=True)
                test_preds = lmap(str.strip, test_preds)
                output_test_pred_file = os.path.join(training_args.output_dir,
                                                     "test_generations.txt")
                with open(output_test_pred_file, "w") as f:
                    f.write("\n".join(test_preds))

    return eval_results
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str,
                                 references=label_str,
                                 rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=valid_data,
)
trainer.train()
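# Note: `tokenizer` and `rouge` are assumed to exist at module scope in the
# snippet above. A minimal sketch of how they are typically created; the
# checkpoint name is illustrative, not taken from the original source:
from datasets import load_metric
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# load_metric("rouge").compute(...) returns aggregate scores whose .mid holds
# precision / recall / fmeasure, as used by compute_metrics above.
rouge = load_metric("rouge")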
def __init__(
    self,
    model,
    pretrain_src1,
    pretrain_src2,
    pretrain_tgt,
    pretrain_enc,
    pretrain_dec,
    pretrain_model,
    epochs,
):
    self.model = model
    self.lm_trainer = LMTrainer(model)
    self.seq2seq_trainer = Seq2SeqTrainer(model)

    if len(pretrain_src1) > 0 and pretrain_enc:
        logging.info("Pretraining src1 encoder")
        self.lm_trainer.train(
            pretrain_src1,
            self.model.src1_vocab,
            self.model.src1_lookup,
            self.model.enc1_fwd_lstm,
            self.model.enc1_bwd_lstm,
            self.model.pret1_w,
            self.model.pret1_b,
            attn=False,
            epochs=epochs,
        )
    if pretrain_src2 and len(pretrain_src2) > 0 and pretrain_enc:
        logging.info("Pretraining src2 encoder")
        self.lm_trainer.train(
            pretrain_src2,
            self.model.src2_vocab,
            self.model.src2_lookup,
            self.model.enc2_fwd_lstm,
            self.model.enc2_bwd_lstm,
            self.model.pret2_w,
            self.model.pret2_b,
            attn=False,
            epochs=epochs,
        )
    if len(pretrain_tgt) > 0 and pretrain_dec:
        logging.info("Pretraining tgt decoder")
        self.lm_trainer.train(
            pretrain_tgt,
            self.model.tgt_vocab,
            self.model.tgt_lookup,
            self.model.dec_lstm,
            None,
            self.model.dec_w,
            self.model.dec_b,
            attn=True,
            epochs=epochs,
        )
    if pretrain_model:
        logging.info("Pretraining seq2seq model")
        self.pretrain_model(pretrain_src1, pretrain_src2, pretrain_tgt, epochs)

    self.model.save()
    )

    if not config.args.train_only and not config.args.testing:
        pretrainer = PretrainHandler(
            model,
            pretrain_src1=config.args.pretrain_src1,
            pretrain_src2=config.args.pretrain_src2,
            pretrain_tgt=config.args.pretrain_tgt,
            pretrain_enc=config.args.pretrain_enc,
            pretrain_dec=config.args.pretrain_dec,
            pretrain_model=config.args.pretrain_s2s,
            epochs=config.args.pretrain_epochs,
        )
    elif not config.args.pretrain_only and not config.args.testing:
        trainer = Seq2SeqTrainer(model, output_name=config.output_name)
        trainer.train_model(
            train_src1=config.args.train_src1,
            train_src2=config.args.train_src2,
            train_tgt=config.args.train_tgt,
            val_src1=config.args.dev_src1,
            val_src2=config.args.dev_src2,
            val_tgt=config.args.dev_tgt,
        )
    elif config.args.testing:
        tester = Seq2SeqTester(model, output_name=config.output_name)
        tester.test(
            src1=config.args.test_src1,
            src2=config.args.test_src2,
            tgt=config.args.test_tgt,
def main(args):
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    tokenizer.bos_token = tokenizer.cls_token
    tokenizer.eos_token = tokenizer.sep_token

    train_texts, train_labels = read_split('train', args.data_path, args.context)
    test_texts, test_labels = read_split('test', args.data_path, args.context)
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        train_texts, train_labels, test_size=.2)

    encoder_max_length = 128
    decoder_max_length = 128

    train_encodings = tokenizer(train_texts,
                                truncation=True,
                                padding=True,
                                max_length=encoder_max_length)
    val_encodings = tokenizer(val_texts,
                              truncation=True,
                              padding=True,
                              max_length=encoder_max_length)
    test_encodings = tokenizer(test_texts,
                               truncation=True,
                               padding=True,
                               max_length=encoder_max_length)

    train_decodings = tokenizer(train_labels,
                                truncation=True,
                                padding=True,
                                max_length=decoder_max_length)
    val_decodings = tokenizer(val_labels,
                              truncation=True,
                              padding=True,
                              max_length=decoder_max_length)
    test_decodings = tokenizer(test_labels,
                               truncation=True,
                               padding=True,
                               max_length=decoder_max_length)

    train_data = ReviewDataset(train_texts, train_labels, train_encodings,
                               train_decodings, tokenizer.pad_token_id)
    val_data = ReviewDataset(val_texts, val_labels, val_encodings,
                             val_decodings, tokenizer.pad_token_id)
    test_data = ReviewDataset(test_texts, test_labels, test_encodings,
                              test_decodings, tokenizer.pad_token_id)

    bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased")

    # set special tokens
    bert2bert.config.decoder_start_token_id = tokenizer.bos_token_id
    bert2bert.config.eos_token_id = tokenizer.eos_token_id
    bert2bert.config.pad_token_id = tokenizer.pad_token_id

    # sensible parameters for beam search
    bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
    bert2bert.config.max_length = 142
    bert2bert.config.min_length = 56
    bert2bert.config.no_repeat_ngram_size = 3
    bert2bert.config.early_stopping = True
    bert2bert.config.length_penalty = 2.0
    bert2bert.config.num_beams = 4

    # set training arguments - these params are not really tuned, feel free to change
    batch_size = 10  # change to 16 for full training
    training_args = Seq2SeqTrainingArguments(
        output_dir="./cptk_yelp",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        predict_with_generate=True,
        evaluate_during_training=True,
        do_train=True,
        do_eval=True,
        logging_steps=1000,  # set to 1000 for full training
        save_steps=800,  # set to 500 for full training
        eval_steps=800,  # set to 8000 for full training
        warmup_steps=2000,  # set to 2000 for full training
        overwrite_output_dir=True,
        save_total_limit=10,
        fp16=True,
        num_train_epochs=1000,
    )

    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=bert2bert,
        args=training_args,
        compute_metrics=build_compute_metrics_fn,
        train_dataset=train_data,
        eval_dataset=val_data,
    )
    trainer.train()
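# `ReviewDataset` is referenced above but not defined in this snippet. A minimal
# sketch of a torch Dataset with the same constructor signature follows; the
# field names and the -100 label masking are assumptions, not the original code.
import torch
from torch.utils.data import Dataset


class ReviewDataset(Dataset):
    """Pairs tokenized source texts with tokenized target texts for seq2seq training."""

    def __init__(self, texts, labels, encodings, decodings, pad_token_id):
        self.texts = texts
        self.labels = labels
        self.encodings = encodings
        self.decodings = decodings
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        item = {
            "input_ids": torch.tensor(self.encodings["input_ids"][idx]),
            "attention_mask": torch.tensor(self.encodings["attention_mask"][idx]),
        }
        # Replace pad tokens in the targets with -100 so the loss ignores them.
        labels = torch.tensor(self.decodings["input_ids"][idx])
        labels[labels == self.pad_token_id] = -100
        item["labels"] = labels
        return item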