def load_pretrained(self):
    """Build a BART model, restore its weights from the pretrained
    checkpoint, and pair it with a fresh optimizer/scheduler.

    Returns:
        (model, optimizer, scheduler, history, learning_rate, best_loss)
        where history/learning_rate are empty lists and best_loss is
        reset to 10.0 — i.e. only the weights carry over; the
        optimisation state starts from scratch.
    """
    # Instantiate the architecture on the target device first, then
    # overwrite its parameters with the saved state.
    model = BartForConditionalGeneration(self.config)
    model.to(self.device)

    # NOTE(review): torch.load has no map_location here — assumes the
    # checkpoint device is compatible with self.device; confirm.
    ckpt = torch.load(self.path_pretrained_model)
    model.load_state_dict(ckpt['model_state_dict'])

    optimizer, scheduler = self.get_optim(model)

    # Fresh bookkeeping for the new training run.
    history = []
    learning_rate = []
    best_loss = 10.0

    return model, optimizer, scheduler, history, learning_rate, best_loss
def main(args):
    """Fine-tune a summarization model with extended position embeddings.

    Builds a SummarizationTrainer, stretches BART's position-embedding
    table from its pretrained size to 3076 slots by tiling the original
    weights, trains via ``generic_train``, and optionally evaluates the
    newest checkpoint when ``args.do_predict`` is set.
    """
    # If output_dir not provided, generate a timestamped folder under ./results.
    if not args.output_dir:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)

    model = SummarizationTrainer(args)
    state = model.model.state_dict()
    old_pos = state['model.encoder.embed_positions.weight']

    # Rebuild the model with a larger maximum sequence length.
    cfg = model.config
    cfg.max_position_embeddings = 3076
    new_model = BartForConditionalGeneration(cfg)

    # Tile the pretrained table into the enlarged one.
    # NOTE(review): the 2052 boundary and the final [2052:] slice assume
    # old_pos tiles exactly into 3076 rows — a shape mismatch here would
    # raise at runtime; confirm against the pretrained table size.
    pos_weight = new_model.model.encoder.embed_positions.weight.cuda()
    pos_weight[:old_pos.shape[0]] = old_pos.cuda()
    pos_weight[old_pos.shape[0]:2052] = old_pos.cuda()
    pos_weight[2052:] = old_pos.cuda()

    # Encoder and decoder share the same enlarged table.
    state['model.decoder.embed_positions.weight'] = pos_weight
    state['model.encoder.embed_positions.weight'] = pos_weight
    new_model.load_state_dict(state, strict=True)
    model.model = new_model.cuda()

    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        ckpts = sorted(
            glob.glob(
                os.path.join(args.output_dir, "checkpointepoch=*.ckpt"),
                recursive=True,
            )
        )
        model = model.load_from_checkpoint(ckpts[-1])
        trainer.test(model)
def load(self):
    """Build a BART model plus optimizer/scheduler, resuming from
    ``<self.path>checkpoint.tar`` when that file exists.

    Returns:
        (model, optimizer, scheduler, history, learning_rate, best_loss).
        When no checkpoint is found, history/learning_rate are empty
        lists and best_loss is 0.0; otherwise all six values come from
        the saved state.
    """
    # NOTE(review): plain string concatenation — assumes self.path ends
    # with a path separator; confirm.
    ckpt_path = self.path + 'checkpoint.tar'

    model = BartForConditionalGeneration(self.config)
    model.to(self.device)
    optimizer, scheduler = self.get_optim()

    # Defaults used when training starts fresh.
    history = []
    learning_rate = []
    best_loss = 0.0

    # Best-effort resume: silently start fresh if no checkpoint exists.
    if os.path.exists(ckpt_path):
        state = torch.load(ckpt_path)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        scheduler.load_state_dict(state['scheduler_state_dict'])
        history = state['history']
        learning_rate = state['learning_rate']
        best_loss = state['best_loss']

    return model, optimizer, scheduler, history, learning_rate, best_loss
# Notebook-style script: copy the pretrained position-embedding table into
# an enlarged model, then load reference BART checkpoints.
correctly_shaped_pos_weight = new_model.model.encoder.embed_positions.weight
print(correctly_shaped_pos_weight)
# %%
# FIX: embed_positions.weight is a leaf Parameter with requires_grad=True;
# in-place slice assignment on it raises
# "RuntimeError: a leaf Variable that requires grad is being used in an
# in-place operation" unless autograd tracking is disabled.
with torch.no_grad():
    # Tile the pretrained table into the front of the enlarged one.
    for i in range(1):
        correctly_shaped_pos_weight[
            i * shorter_pos_embeds.shape[0]:(i + 1) * shorter_pos_embeds.shape[0]
        ] = shorter_pos_embeds
    # Fill the remainder, skipping the first two rows of the source table.
    # NOTE(review): assumes the tail slice and shorter_pos_embeds[2:, :]
    # have matching lengths — confirm the table sizes.
    correctly_shaped_pos_weight[1 * shorter_pos_embeds.shape[0]:] = \
        shorter_pos_embeds[2:, :]
# %%
# FIX: torch.tensor(tensor.data) is the deprecated copy-construct pattern
# (PyTorch warns to use sourceTensor.clone().detach() instead); detach()
# then clone() produces the same independent, grad-free copy.
shared_pos_weight = correctly_shaped_pos_weight.detach().clone()
sd['model.decoder.embed_positions.weight'] = shared_pos_weight
sd['model.encoder.embed_positions.weight'] = shared_pos_weight
new_model.load_state_dict(sd, strict=True)
print(new_model.model.encoder.embed_positions.weight)
# %%
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

# Reference checkpoint for summarization experiments.
model = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')
# %% Conditional Generation Example
# Mask filling only works for bart-large
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
TXT = "My friends are <mask> but they eat too many carbs."