def get_model_and_tokenizer(args):
    config_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    model_config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None)
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config, label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length + args.max_target_seq_length)
    logger.info("Model config for seq2seq: %s", str(config))

    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    model = BertForSequenceToSequence.from_pretrained(
        args.model_name_or_path, config=config, model_type=args.model_type,
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)

    return model, tokenizer
def get_model_and_tokenizer(args):
    model_config = UnilmConfig.from_pretrained(
        args.config_name if args.config_name else 'unilm-base-cased',
        cache_dir=args.cache_dir if args.cache_dir else None)
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config, label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length + args.max_target_seq_length)
    logger.info("Model config for seq2seq: %s", str(config))

    tokenizer = UnilmTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else 'unilm-base-cased',
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    generator = BertForSequenceToSequence.from_pretrained(
        'unilm-base-cased', config=config, model_type='unilm',
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)
    generator.to(args.device)

    classifier = Classifier(config.hidden_size, args.num_labels)
    classifier.to(args.device)

    logger.info("Initialize retriever.")
    retriever = Retriever(args, tokenizer)

    return generator, classifier, tokenizer, retriever
def get_model_and_tokenizer(args):
    config_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    # Hack to cope with the updated version of the Transformers API
    if args.model_type in ['minilm', 'unilm', 'xbert']:
        config_file = config_class.pretrained_config_archive_map[args.model_name_or_path]
        vocab_file = tokenizer_class.pretrained_vocab_files_map['vocab_file'][args.model_name_or_path]
        model_file = args.model_name_or_path
    elif os.path.exists(args.model_name_or_path):
        vocab_file = os.path.join(args.model_name_or_path, 'vocab.txt')
        config_file = os.path.join(args.model_name_or_path, 'config.json')
        model_file = os.path.join(args.model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(vocab_file)
        assert os.path.exists(config_file)
        assert os.path.exists(model_file)
    else:
        vocab_file = args.model_name_or_path
        config_file = args.model_name_or_path
        model_file = args.model_name_or_path

    model_config = config_class.from_pretrained(
        args.config_name if args.config_name else config_file,
        cache_dir=args.cache_dir if args.cache_dir else None)
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config, label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length + args.max_target_seq_length)
    logger.info("Model config for seq2seq: %s", str(config))

    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else vocab_file,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    model = BertForSequenceToSequence.from_pretrained(
        model_file, config=config, model_type=args.model_type,
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)

    return model, tokenizer
def get_model_and_tokenizer(args):
    config_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    model_config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None)
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config, label_smoothing=args.label_smoothing,
        max_position_embeddings=args.max_source_seq_length + args.max_target_seq_length)
    logger.info("Model config for seq2seq: %s", str(config))

    if args.prepend_len:
        tgt_segments = [85] + list(range(100, 400, 15)) + [400]
        additional_special_tokens = [f'[unused{seg}]' for seg in tgt_segments]
        logger.info(f'additional_special_tokens: {additional_special_tokens}')
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case,
            cache_dir=args.cache_dir if args.cache_dir else None,
            additional_special_tokens=additional_special_tokens)
    else:
        tokenizer = tokenizer_class.from_pretrained(
            args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
            do_lower_case=args.do_lower_case,
            cache_dir=args.cache_dir if args.cache_dir else None)

    model = BertForSequenceToSequence.from_pretrained(
        args.model_name_or_path, config=config, model_type=args.model_type,
        reuse_position_embedding=True,
        cache_dir=args.cache_dir if args.cache_dir else None)

    return model, tokenizer
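# Illustrative sketch (not part of the original code): the fields that the
# get_model_and_tokenizer variants above read from `args`, collected in a plain
# argparse.Namespace. The attribute names are taken from the functions
# themselves; the values are hypothetical examples, not the project's defaults.
from argparse import Namespace

example_args = Namespace(
    model_type="unilm",
    model_name_or_path="unilm-base-cased",
    config_name=None,
    tokenizer_name=None,
    cache_dir=None,
    do_lower_case=False,
    label_smoothing=0.1,
    max_source_seq_length=464,
    max_target_seq_length=48,
    prepend_len=False,
)
# model, tokenizer = get_model_and_tokenizer(example_args)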
def train(args, training_features, doc_features, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0] and args.log_dir:
        tb_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        tb_writer = None

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    else:
        amp = None

    # model recover
    recover_step = utils.get_max_epoch_model(args.output_dir)

    # if recover_step:
    #     model_recover_checkpoint = os.path.join(args.output_dir, "model.{}.bin".format(recover_step))
    #     logger.info(" ** Recover model checkpoint in %s ** ", model_recover_checkpoint)
    #     model_state_dict = torch.load(model_recover_checkpoint, map_location='cpu')
    #     optimizer_recover_checkpoint = os.path.join(args.output_dir, "optim.{}.bin".format(recover_step))
    #     checkpoint_state_dict = torch.load(optimizer_recover_checkpoint, map_location='cpu')
    #     checkpoint_state_dict['model'] = model_state_dict
    # else:
    checkpoint_state_dict = None

    model.to(args.device)
    model, optimizer = prepare_for_training(args, model, checkpoint_state_dict, amp=amp)

    if args.n_gpu == 0 or args.no_cuda:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.gradient_accumulation_steps
    else:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps

    train_batch_size = per_node_train_batch_size * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    global_step = recover_step if recover_step else 0

    if args.num_training_steps == -1:
        args.num_training_steps = int(args.num_training_epochs * len(training_features) / train_batch_size)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_training_steps, last_epoch=-1)

    if checkpoint_state_dict:
        scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"])

    train_dataset = utils.RetrievalSeq2seqDatasetForBert(
        features=training_features, max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id,
        pad_id=tokenizer.pad_token_id, mask_id=tokenizer.mask_token_id,
        random_prob=args.random_prob, keep_prob=args.keep_prob,
        offset=train_batch_size * global_step,
        num_training_instances=train_batch_size * args.num_training_steps,
    )

    concator = utils.Concator(
        max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id,
        pad_id=tokenizer.pad_token_id, mask_id=tokenizer.mask_token_id,
        random_prob=args.random_prob, keep_prob=args.keep_prob,
        offset=train_batch_size * global_step,
        num_training_instances=train_batch_size * args.num_training_steps,
    )
    if hasattr(model, "module"):
        model.module.concator = concator
    else:
        model.concator = concator

    # build document embeds
    logger.info("Building embeds for %d documents" % len(doc_features))
    doc_dataset = utils.RetrievalSeq2seqDocDatasetForBert(
        features=doc_features, max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length, vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id,
        pad_id=tokenizer.pad_token_id, mask_id=tokenizer.mask_token_id,
        random_prob=args.random_prob, keep_prob=args.keep_prob,
        offset=train_batch_size * global_step,
        num_training_instances=train_batch_size * args.num_training_steps,
    )
    doc_sampler = SequentialSampler(doc_dataset)
    doc_dataloader = DataLoader(
        doc_dataset, sampler=doc_sampler,
        batch_size=per_node_train_batch_size // args.gradient_accumulation_steps,
        collate_fn=utils.batch_list_to_batch_tensors)
    doc_iterator = tqdm.tqdm(
        doc_dataloader, initial=global_step,
        desc="Embedding docs:", disable=args.local_rank not in [-1, 0])

    all_embeds = []
    model.eval()
    model.zero_grad()
    for step, batch in enumerate(doc_iterator):
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            embeds = (model.module.retrieval.get_embeds(batch[0])
                      if hasattr(model, "module")
                      else model.retrieval.get_embeds(batch[0]))
        all_embeds.extend(embeds.view(-1, 768).detach().cpu().tolist())

    if hasattr(model, "module"):
        model.module.retrieval.doc_embeds = torch.tensor(all_embeds, dtype=torch.float32)
        model.module.retrieval.build_indexs_from_embeds(model.module.retrieval.doc_embeds)
    else:
        model.retrieval.doc_embeds = torch.tensor(all_embeds, dtype=torch.float32)
        model.retrieval.build_indexs_from_embeds(model.retrieval.doc_embeds)

    logger.info("start training")
    if args.ckpt_path:
        logger.info("continue training from %s" % args.ckpt_path)
        config_class, tokenizer_class = MODEL_CLASSES[args.model_type]
        model_config = config_class.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            cache_dir=args.cache_dir if args.cache_dir else None)
        config = BertForSeq2SeqConfig.from_exist_config(
            config=model_config, label_smoothing=args.label_smoothing,
            max_position_embeddings=args.max_source_seq_length + args.max_target_seq_length)
        model = BertForRetrievalSeq2Seq.from_pretrained(
            args.ckpt_path, config=config, model_type=args.model_type,
            reuse_position_embedding=True, retrieval=config,
            cache_dir=args.cache_dir if args.cache_dir else None)

    logger.info("Check dataset:")
    for i in range(5):
        source_ids, target_ids, num_source_tokens, num_target_tokens = train_dataset.__getitem__(i)
        logger.info("Instance-%d" % i)
        logger.info("Source tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(source_ids)))
        logger.info("Target tokens = %s" % " ".join(tokenizer.convert_ids_to_tokens(target_ids)))

    logger.info("Model = %s" % str(model))

    # Train!
    logger.info(" ***** Running training *****")
    logger.info(" Num examples = %d", len(training_features))
    logger.info(" Num Epochs = %.2f", len(train_dataset) / len(training_features))
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(" Batch size per node = %d", per_node_train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", train_batch_size)
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", args.num_training_steps)

    if args.num_training_steps <= global_step:
        logger.info("Training is done. Please use a new dir or clean this dir!")
    else:
        # The training features are shuffled
        train_sampler = SequentialSampler(train_dataset) \
            if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False)
        train_dataloader = DataLoader(
            train_dataset, sampler=train_sampler,
            batch_size=per_node_train_batch_size // args.gradient_accumulation_steps,
            collate_fn=utils.batch_list_to_batch_tensors)

        train_iterator = tqdm.tqdm(
            train_dataloader, initial=global_step,
            desc="Iter (loss=X.XXX, lr=X.XXXXXXX)", disable=args.local_rank not in [-1, 0])

        model.train()
        model.zero_grad()
        tr_loss, logging_loss = 0.0, 0.0

        for step, batch in enumerate(train_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'source_ids': batch[0],
                      'target_ids': batch[1],
                      # 'pseudo_ids': batch[2],
                      'num_source_tokens': batch[2],
                      'num_target_tokens': batch[3]}
            loss = model(**inputs)
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training

            train_iterator.set_description('Iter (loss=%5.3f) lr=%9.7f' % (loss.item(), scheduler.get_lr()[0]))

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            logging_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info("")
                    logger.info(" Step [%d ~ %d]: %.2f", global_step - args.logging_steps, global_step, logging_loss)
                    logging_loss = 0.0

                if args.local_rank in [-1, 0] and args.save_steps > 0 and \
                        (global_step % args.save_steps == 0 or global_step == args.num_training_steps):
                    save_path = os.path.join(args.output_dir, "ckpt-%d" % global_step)
                    os.makedirs(save_path, exist_ok=True)
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(save_path)

                    # optim_to_save = {
                    #     "optimizer": optimizer.state_dict(),
                    #     "lr_scheduler": scheduler.state_dict(),
                    # }
                    # if args.fp16:
                    #     optim_to_save["amp"] = amp.state_dict()
                    # torch.save(
                    #     optim_to_save, os.path.join(args.output_dir, 'optim.{}.bin'.format(global_step)))

                    logger.info("Saving model checkpoint %d into %s", global_step, save_path)

    if args.local_rank in [-1, 0] and tb_writer:
        tb_writer.close()
def __init__(
    self,
    model_name="unilm-base-cased",
    to_lower=False,
    cache_dir=".",
    load_model_from_dir=None,
    model_file_name=None,
    label_smoothing=0.1,
    max_seq_length=512,
    max_source_seq_length=464,
    max_target_seq_length=48,
):
    """
    Abstractive summarizer based on s2s-ft.

    Args:
        model_name (str, optional): Name of the model.
            Call `S2SAbstractiveSummarizer.list_supported_models()` to see all
            supported model names. Defaults to "unilm-base-cased".
        to_lower (bool, optional): Whether to convert all letters to lower case
            during tokenization. This is determined by whether a cased model is
            used. Defaults to False, which corresponds to a cased model.
        cache_dir (str, optional): Directory to cache downloaded model files.
            Defaults to ".".
        load_model_from_dir (str, optional): Directory to load the model from.
            If model_file_name is not provided, assume the model was saved by
            `:func:`~transformers.PreTrainedModel.save_pretrained`` and the
            directory should contain pytorch_model.bin and config.json.
            Defaults to None.
        model_file_name (str, optional): Name of the model file under
            `load_model_from_dir`. If provided, assume the model was saved by
            `S2SAbstractiveSummarizer.save_model`.
        label_smoothing (float, optional): Alpha in label smoothing.
            Defaults to 0.1.
        max_seq_length (int, optional): Maximum length of the sequence that
            concatenates source sequence tokens, target sequence tokens, and
            special tokens like cls and sep. Defaults to 512.
        max_source_seq_length (int, optional): Maximum number of tokens in the
            source sequence after tokenization. Defaults to 464.
        max_target_seq_length (int, optional): Maximum number of tokens in the
            target sequence after tokenization. Defaults to 48.
    """
    if model_name not in self.list_supported_models():
        raise ValueError(
            "Model name {0} is not supported by {1}. "
            "Call '{1}.list_supported_models()' to get all supported model "
            "names.".format(model_name, self.__class__.__name__))

    model_class = MODEL_CLASS[model_name]
    config_class = CONFIG_CLASS[model_name]

    self._model_name = model_name
    self._model_type = _get_model_type(self._model_name)
    # self._bert_model_name is needed for BertForSeq2SeqDecoder
    if self._model_type != "bert":
        if self._model_type == "roberta":
            self._bert_model_name = (
                self._model_name.replace("roberta", "bert") + "-cased")
        else:
            self._bert_model_name = "bert-" + self._model_name.split("-", 1)[-1]
    else:
        self._bert_model_name = self._model_name

    self.cache_dir = cache_dir
    self.load_model_from_dir = load_model_from_dir
    self.do_lower_case = to_lower
    self.max_seq_length = max_seq_length
    self.max_source_seq_length = max_source_seq_length
    self.max_target_seq_length = max_target_seq_length

    if load_model_from_dir is None:
        model_to_load = self._model_name
    elif model_file_name is None:
        # Assume model was saved by
        # `:func:`~transformers.PreTrainedModel.save_pretrained``.
        # The load_model_from_dir should contain pytorch_model.bin and
        # config.json and can be loaded by
        # `:func:`~transformers.PreTrainedModel.from_pretrained``.
        logger.info("Loading cached model from {}".format(load_model_from_dir))
        model_to_load = load_model_from_dir
    else:
        # Assume model was saved by S2SAbstractiveSummarizer.save_model
        model_to_load = os.path.join(load_model_from_dir, model_file_name)
        logger.info("Loading cached model from {}".format(model_to_load))

    if load_model_from_dir is not None and model_file_name is None:
        # Assume config.json is in load_model_from_dir
        model_config = config_class.from_pretrained(load_model_from_dir, cache_dir=cache_dir)
    else:
        model_config = config_class.from_pretrained(self._model_name, cache_dir=cache_dir)

    # Convert regular model config to sequence to sequence config
    config = BertForSeq2SeqConfig.from_exist_config(
        config=model_config,
        label_smoothing=label_smoothing,
        max_position_embeddings=self.max_source_seq_length + self.max_target_seq_length,
    )
    logger.info("Model config for seq2seq: %s", str(config))

    self.model = model_class.from_pretrained(
        model_to_load,
        config=config,
        model_type=self._model_type,
        cache_dir=cache_dir,
        reuse_position_embedding=True,
    )

    self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(
        self._model_name,
        do_lower_case=to_lower,
        cache_dir=cache_dir,
        output_loading_info=False,
    )
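# Minimal usage sketch (not part of the original code), assuming the __init__
# above belongs to a class named S2SAbstractiveSummarizer that is importable
# in the current scope. The keyword values simply mirror the documented
# defaults; fine-tuning and prediction calls are omitted because they are not
# shown in this section.
summarizer = S2SAbstractiveSummarizer(
    model_name="unilm-base-cased",  # cased checkpoint, so to_lower stays False
    to_lower=False,
    cache_dir=".",
    label_smoothing=0.1,
    max_seq_length=512,
    max_source_seq_length=464,
    max_target_seq_length=48,
)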