Example #1
    def load_pretrained(self):
        # Fresh training state: only the model weights come from the checkpoint.
        history = []
        learning_rate = []
        best_loss = 10.0

        # map_location keeps the load working even if the checkpoint was saved
        # on a different device than self.device.
        checkpoint = torch.load(self.path_pretrained_model, map_location=self.device)
        model = BartForConditionalGeneration(self.config)
        model.to(self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer, scheduler = self.get_optim(model)
        return model, optimizer, scheduler, history, learning_rate, best_loss
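
For reference, a minimal sketch of how a compatible checkpoint might be written; load_pretrained above only reads the 'model_state_dict' key, and the save_pretrained name here is hypothetical:

    def save_pretrained(self, model):
        # Hypothetical counterpart to load_pretrained: only the weights are
        # persisted, since load_pretrained resets history and optimizer state.
        torch.save({'model_state_dict': model.state_dict()},
                   self.path_pretrained_model)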
Example #2
def main(args):

    # If output_dir is not provided, create a timestamped folder under
    # ./results in the current working directory
    if not args.output_dir:
        args.output_dir = os.path.join(
            "./results",
            f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",
        )
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    sd = model.model.state_dict()
    # Pretrained table has max_position_embeddings + 2 rows (BART's positional
    # offset), i.e. 1026 rows for the default 1024 positions.
    shorter_pos_embeds = sd['model.encoder.embed_positions.weight']
    new_config = model.config
    new_config.max_position_embeddings = 3076  # 3076 + 2 = 3078 = 3 * 1026
    new_model = BartForConditionalGeneration(new_config)
    # Work on .data so the slice assignments below bypass autograd.
    correctly_shaped_pos_weight = new_model.model.encoder.embed_positions.weight.data.cuda()
    # Tile the pretrained rows three times to fill the enlarged table.
    correctly_shaped_pos_weight[:shorter_pos_embeds.shape[0]] = shorter_pos_embeds.cuda()
    correctly_shaped_pos_weight[shorter_pos_embeds.shape[0]:2052] = shorter_pos_embeds.cuda()
    correctly_shaped_pos_weight[2052:] = shorter_pos_embeds.cuda()
    # Reuse the tiled table for both encoder and decoder positions.
    sd['model.decoder.embed_positions.weight'] = correctly_shaped_pos_weight
    sd['model.encoder.embed_positions.weight'] = correctly_shaped_pos_weight
    new_model.load_state_dict(sd, strict=True)
    model.model = new_model.cuda()
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # PyTorch Lightning uses this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        checkpoints = sorted(
            glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"),
                      recursive=True))
        model = model.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
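
The slice bounds above follow from BART's learned positional embeddings reserving 2 extra rows: a 1024-position model stores 1026 rows, so 2052 = 2 * 1026, and max_position_embeddings = 3076 yields a 3078-row table, exactly three tiles. A sanity-check sketch, assuming the default bart-large geometry:

old_rows = 1024 + 2                  # rows in the pretrained table
new_rows = 3076 + 2                  # rows in the enlarged table
assert new_rows == 3 * old_rows      # the three slice assignments tile exactly
assert shorter_pos_embeds.shape[0] == old_rows
assert correctly_shaped_pos_weight.shape[0] == new_rows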
Example #3
    def load(self):
        # Defaults for a fresh run; overwritten below if a checkpoint exists.
        history = []
        learning_rate = []
        best_loss = 0.0

        model = BartForConditionalGeneration(self.config)
        model.to(self.device)
        optimizer, scheduler = self.get_optim()

        checkpoint_path = self.path + 'checkpoint.tar'
        if os.path.exists(checkpoint_path):
            # map_location keeps the load working across devices.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            history = checkpoint['history']
            learning_rate = checkpoint['learning_rate']
            best_loss = checkpoint['best_loss']

        return model, optimizer, scheduler, history, learning_rate, best_loss
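
A matching save routine would need to write every key that load() reads back. A minimal sketch under that assumption (the method name and argument list are illustrative):

    def save(self, model, optimizer, scheduler, history, learning_rate, best_loss):
        # Hypothetical counterpart to load(): each key below is read back there.
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'history': history,
            'learning_rate': learning_rate,
            'best_loss': best_loss,
        }, self.path + 'checkpoint.tar')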
Example #4
correctly_shaped_pos_weight = new_model.model.encoder.embed_positions.weight.data
print(correctly_shaped_pos_weight)

# %%
# Tile the pretrained positional embeddings into the enlarged table: the loop
# copies full blocks, then the tail is filled with the pretrained rows minus
# the first 2 (BART's positional offset), so only real positions repeat.
# Working on .data above keeps these slice writes out of autograd's way.
for i in range(1):
    start = i * shorter_pos_embeds.shape[0]
    end = (i + 1) * shorter_pos_embeds.shape[0]
    correctly_shaped_pos_weight[start:end] = shorter_pos_embeds

correctly_shaped_pos_weight[1 * shorter_pos_embeds.shape[0]:] = shorter_pos_embeds[2:, :]
# %%
# .clone() takes an independent copy; torch.tensor() on an existing tensor is
# discouraged and only warns to use clone().detach() instead.
sd['model.decoder.embed_positions.weight'] = correctly_shaped_pos_weight.clone()

sd['model.encoder.embed_positions.weight'] = correctly_shaped_pos_weight.clone()

new_model.load_state_dict(sd, strict=True)

print(new_model.model.encoder.embed_positions.weight)

# %%
import torch
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
model = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')
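
# %%
# Illustrative use of the distilbart checkpoint loaded above: plain
# summarization via generate(). The example text and beam settings are
# assumptions, not values from the original notebook.
ARTICLE = "PyTorch provides GPU-accelerated tensor computation with automatic differentiation."
inputs = tokenizer([ARTICLE], return_tensors='pt', truncation=True)
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=60,
                             early_stopping=True)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])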


# %% Conditional Generation Example
# Mask filling only works for bart-large
from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
TXT = "My friends are <mask> but they eat too many carbs."