Example 1
    def __init__(self, args):
        self.args = args
        # t5-11b is sharded across GPUs, so each shard picks its own starting device;
        # every other model runs on cuda:0
        self.device = 'cuda:0' if self.args.model_name != 't5-11b' else 'cuda:{}'.format(
            self.args.t5_shard * 4)

        if self.args.use_original_template and (
                not self.args.use_lm_finetune) and (
                    not self.args.only_evaluate):
            raise RuntimeError(
                "If args.use_original_template is True, either "
                "args.use_lm_finetune or args.only_evaluate should be True.")

        # load tokenizer
        tokenizer_src = 'roberta-large' if 'megatron' in self.args.model_name else self.args.model_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_src,
                                                       use_fast=False)
        init_vocab(args)

        # load datasets and dataloaders
        self.relation, self.data_path_pre, self.data_path_post = self.get_TREx_parameters(
        )

        self.train_data = load_file(
            join(self.args.data_dir,
                 self.data_path_pre + 'train' + self.data_path_post))
        self.dev_data = load_file(
            join(self.args.data_dir,
                 self.data_path_pre + 'dev' + self.data_path_post))
        self.test_data = load_file(
            join(self.args.data_dir,
                 self.data_path_pre + 'test' + self.data_path_post))

        self.test_set = LAMADataset('test', self.test_data, self.tokenizer,
                                    self.args)
        self.train_set = LAMADataset('train', self.train_data, self.tokenizer,
                                     self.args)
        self.dev_set = LAMADataset('dev', self.dev_data, self.tokenizer,
                                   self.args)
        os.makedirs(self.get_save_path(), exist_ok=True)

        self.train_loader = DataLoader(self.train_set,
                                       batch_size=8,
                                       shuffle=True,
                                       drop_last=True)
        self.dev_loader = DataLoader(self.dev_set, batch_size=8)
        self.test_loader = DataLoader(self.test_set, batch_size=8)

        self.model = PTuneForLAMA(args, self.device, self.args.template)
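
The constructor above only depends on the fields it reads from args, so it can be driven by any object exposing those attributes. The sketch below shows one minimal, hypothetical way to do that; the Trainer class name and every argument value are assumptions for illustration, not part of the snippet.

from argparse import Namespace

# Hypothetical argument set: field names mirror the attributes read in
# __init__ (model_name, data_dir, relation_id, template, ...); values are illustrative.
args = Namespace(
    model_name='bert-base-cased',
    data_dir='data/LAMA',
    relation_id='P19',
    template=(3, 3, 3),
    pseudo_token='[PROMPT]',
    use_original_template=False,
    use_lm_finetune=False,
    only_evaluate=False,
    t5_shard=0,
)

trainer = Trainer(args)  # assumed name of the class whose __init__ is shown above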
Example 2
    def get_TREx_parameters(self):
        relation = load_file(
            join(self.args.data_dir,
                 "single_relations/{}.jsonl".format(self.args.relation_id)))[0]
        data_path_pre = "fact-retrieval/original/{}/".format(
            self.args.relation_id)
        data_path_post = ".jsonl"
        return relation, data_path_pre, data_path_post
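
The two returned path pieces are later combined with the split name in __init__. The short sketch below simply replays that string formatting under an assumed data_dir and a hypothetical relation_id, to show which files end up being loaded.

from os.path import join

data_dir = 'data/LAMA'   # assumed location of the LAMA data
relation_id = 'P19'      # hypothetical relation id

data_path_pre = "fact-retrieval/original/{}/".format(relation_id)
data_path_post = ".jsonl"

# Mirrors how __init__ builds the train/dev/test file paths from the returned pieces.
for split in ('train', 'dev', 'test'):
    print(join(data_dir, data_path_pre + split + data_path_post))
# data/LAMA/fact-retrieval/original/P19/train.jsonl  (and likewise for dev and test)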
Example 3
    def __init__(self, args, device, template):
        super().__init__()
        self.args = args
        self.device = device

        # load relation templates
        self.relation_templates = dict(
            (d['relation'], d['template'])
            for d in load_file(join(self.args.data_dir, 'relations.jsonl')))

        # load tokenizer
        tokenizer_src = 'roberta-large' if 'megatron' in self.args.model_name else self.args.model_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_src,
                                                       use_fast=False)

        # load pre-trained model
        if 'megatron' in self.args.model_name and self.args.use_lm_finetune:
            raise RuntimeError(
                "Cannot apply args.use_lm_finetune=True on MegatronLM 11B.")
        self.model = create_model(self.args)
        self.model = self.model.to(self.device)
        for param in self.model.parameters():
            param.requires_grad = self.args.use_lm_finetune
        self.embeddings = get_embedding_layer(self.args, self.model)

        # set allowed vocab set
        self.vocab = self.tokenizer.get_vocab()
        self.allowed_vocab_ids = set(
            self.vocab[k]
            for k in get_vocab_by_strategy(self.args, self.tokenizer))

        # GPT-style and MegatronLM models generate the target at the end,
        # so the trailing block of prompt tokens is dropped
        if 'gpt' in self.args.model_name or 'megatron' in self.args.model_name:
            template = (template[0], template[1], 0)
        self.template = template

        # load prompt encoder
        self.hidden_size = self.embeddings.embedding_dim
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': [self.args.pseudo_token]})
        self.pseudo_token_id = self.tokenizer.get_vocab()[
            self.args.pseudo_token]
        # fall back to the unk token when the tokenizer defines no pad token
        if self.tokenizer.pad_token_id is not None:
            self.pad_token_id = self.tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.unk_token_id

        self.spell_length = sum(self.template)
        self.prompt_encoder = PromptEncoder(self.template, self.hidden_size,
                                            self.tokenizer, self.device, args)
        self.prompt_encoder = self.prompt_encoder.to(self.device)
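
The template triple decides how many trainable pseudo tokens go in each slot, and spell_length is simply their sum. The sketch below, under assumed model_name and template values, replays the adjustment made above for GPT-style models.

model_name = 'gpt2-medium'   # assumed model name, for illustration only
template = (3, 3, 3)         # hypothetical prompt layout

# Same adjustment as in __init__: GPT-style and MegatronLM models drop the third block.
if 'gpt' in model_name or 'megatron' in model_name:
    template = (template[0], template[1], 0)

spell_length = sum(template)
print(template, spell_length)  # (3, 3, 0) 6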