Example #1
def fine_tune(pos_action, neg_action, tokenizer, model):
    nlg_usr = TemplateNLG(is_user=True)
    nlg_sys = TemplateNLG(is_user=False)
    pos_train_usr_utter = []
    pos_train_sys_utter = []
    neg_train_usr_utter = []
    neg_train_sys_utter = []

    for turn in pos_action:
        if turn[0] != [] and turn[1] != []:
            s_u = nlg_usr.generate(turn[0])
            s_a = nlg_sys.generate(turn[1])
            pos_train_usr_utter.append(s_u)
            pos_train_sys_utter.append(s_a)
    for turn in neg_action:
        if turn[0] != [] and turn[1] != []:
            s_u = nlg_usr.generate(turn[0])
            s_a = nlg_sys.generate(turn[1])
            neg_train_usr_utter.append(s_u)
            neg_train_sys_utter.append(s_a)

    train_usr_utter = pos_train_usr_utter + neg_train_usr_utter
    train_sys_utter = pos_train_sys_utter + neg_train_sys_utter

    train_encoding = tokenizer(train_usr_utter,
                               train_sys_utter,
                               padding=True,
                               truncation=True,
                               max_length=80)
    train_encoding['label'] = [1] * len(pos_train_usr_utter) + [0] * len(
        neg_train_usr_utter)
    train_dataset = Dataset.from_dict(train_encoding)
    train_dataset.set_format('torch',
                             columns=['input_ids', 'attention_mask', 'label'])
    save_dir = os.path.join(root_dir,
                            'convlab2/policy/dqn/NLE/save/script_fine_tune')
    log_dir = os.path.join(
        root_dir, 'convlab2/policy/dqn/NLE/save/script_fine_tune/logs')
    training_args = TrainingArguments(
        output_dir=save_dir,
        num_train_epochs=2,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=128,
        warmup_steps=500,
        weight_decay=0.01,
        evaluate_during_training=False,
        logging_dir=log_dir,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model(os.path.join(save_dir, 'fine_tune_checkpoint'))
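
A minimal sketch of how this helper might be driven, assuming ConvLab-2 style dialog acts of [intent, domain, slot, value] quadruples; the sample acts, tokenizer, and model below are placeholders rather than part of the original:

# Hypothetical driver for fine_tune() above; the act lists are made-up placeholders.
# Requires ConvLab-2 (TemplateNLG is used inside fine_tune) and a defined root_dir.
from transformers import RobertaTokenizer, RobertaForSequenceClassification

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')

# One (user acts, system acts) pair per turn; real pairs would be sampled from dialogues.
pos_action = [([['Inform', 'Hotel', 'Area', 'east']],
               [['Request', 'Hotel', 'Price', '?']])]
neg_action = [([['Inform', 'Hotel', 'Area', 'east']],
               [['Inform', 'Train', 'Dest', 'cambridge']])]

fine_tune(pos_action, neg_action, tokenizer, model)
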
Example #2
def __call__(self, examples):
    # Wrap the raw texts in a Dataset and tokenize them in batches.
    ds = Dataset.from_dict({'text': examples})
    ds = ds.map(lambda batch: self.tokenizer(batch['text'], truncation=True, padding='max_length'),
                batched=True, batch_size=512)
    ds.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])
    dataloader = torch.utils.data.DataLoader(ds, batch_size=16)
    res = []
    for batch in tqdm(dataloader):
        # Move each tensor to the model's device and run a forward pass.
        batch = {k: v.to(self.device) for k, v in batch.items()}
        outputs = self.model(**batch)
        res.append(outputs[0].softmax(1).detach().cpu())
    # Concatenate per-batch probabilities into one (n_examples, n_labels) array.
    return torch.cat(res, dim=0).numpy()
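
For context, a sketch of a wrapper class this __call__ could belong to, assuming self.tokenizer, self.model, and self.device are prepared in __init__; the class name and checkpoint are assumptions:

# Hypothetical host class for the __call__ above; the name and checkpoint are assumptions.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class TextClassifier:
    def __init__(self, model_name='bert-base-uncased', device='cuda'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
        self.model.eval()
        self.device = device

    # the __call__ method shown above would sit here

# Usage: probs = TextClassifier()(["first sentence", "second sentence"])
# returns an (n_examples, n_labels) numpy array of softmax probabilities.
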
Example #3

train_usr_utter, train_sys_utter = generate_data(multiwoz_train)
val_usr_utter, val_sys_utter = generate_data(multiwoz_val)
test_usr_utter, test_sys_utter = generate_data(multiwoz_test)

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForSequenceClassification.from_pretrained('roberta-base')
train_encoding = tokenizer(train_usr_utter,
                           train_sys_utter,
                           padding=True,
                           truncation=True,
                           max_length=80)
train_encoding['label'] = [1] * (len(train_usr_utter) //
                                 2) + [0] * (len(train_usr_utter) // 2)
train_dataset = Dataset.from_dict(train_encoding)
train_dataset.set_format('torch',
                         columns=['input_ids', 'attention_mask', 'label'])
val_encoding = tokenizer(val_usr_utter,
                         val_sys_utter,
                         padding=True,
                         truncation=True,
                         max_length=80)
val_encoding['label'] = [1] * (len(val_usr_utter) //
                               2) + [0] * (len(val_usr_utter) // 2)
val_dataset = Dataset.from_dict(val_encoding)
val_dataset.set_format('torch',
                       columns=['input_ids', 'attention_mask', 'label'])
test_encoding = tokenizer(test_usr_utter,
                          test_sys_utter,
                          padding=True,
                          truncation=True,
                          max_length=80)
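
A possible continuation of this example, wrapping the test split the same way and handing the train and validation sets to a Trainer; the hyperparameters and output path are assumptions, not taken from the original:

# Hypothetical continuation; labels follow the same half-positive/half-negative pattern as above.
from transformers import Trainer, TrainingArguments

test_encoding['label'] = [1] * (len(test_usr_utter) // 2) + [0] * (len(test_usr_utter) // 2)
test_dataset = Dataset.from_dict(test_encoding)
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

training_args = TrainingArguments(output_dir='./pair_cls',  # placeholder path
                                  num_train_epochs=2,
                                  per_device_train_batch_size=32,
                                  evaluation_strategy='epoch')
trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=train_dataset,
                  eval_dataset=val_dataset)
trainer.train()
print(trainer.evaluate(test_dataset))
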
Example #4
                    one_token, one_tag = line_split
                    # strip the trailing newline
                    one_tag = one_tag.rstrip('\n')
                    sentence.append(one_token)
                    sentence_tag.append(one_tag)
                else:
                    print(f"这一行出现问题{line},不是2个字段")
            else:
                # An empty line marks the start of the next sentence: append sentence and sentence_tag to the overall tokens and ner_tags, then reset them
                if sentence and sentence_tag:
                    tokens.append(sentence)
                    ner_tags.append(sentence_tag)
                    sentence = []
                    sentence_tag = []
    if len(tokens) != len(ner_tags):
        print('tokens and ner_tags have different lengths; the file being read is malformed, please check it')
        result = {'tokens': [], 'ner_tags': []}
    else:
        result = {'tokens': tokens, 'ner_tags': ner_tags}
    return result


if __name__ == '__main__':
    train_file = "msra/msra_train_bio.txt"
    test_file = "msra/msra_test_bio.txt"
    mini_file = "msra/mini.txt"
    # test_dict = read_ner_txt(test_file)
    # dataset = Dataset.from_dict(test_dict)
    mini_dict = read_ner_txt(mini_file)
    dataset = Dataset.from_dict(mini_dict)
    print(dataset)
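
This example and the next both start mid-function; as a rough, self-contained sketch, the full read_ner_txt reader for this tab-separated BIO format could look like the following (the signature, file encoding, and trailing-sentence handling are assumptions):

# Hypothetical reconstruction of read_ner_txt; only the inner loop body is visible
# in the original snippets, the rest is an assumption.
def read_ner_txt(file_path):
    tokens, ner_tags = [], []
    sentence, sentence_tag = [], []
    with open(file_path, encoding='utf-8') as f:
        for line in f:
            if line.strip():                      # non-empty line: one "token<TAB>tag" pair
                line_split = line.split('\t')
                if len(line_split) == 2:
                    one_token, one_tag = line_split
                    sentence.append(one_token)
                    sentence_tag.append(one_tag.rstrip('\n'))
                else:
                    print(f"Problem with line {line!r}: expected 2 fields")
            else:                                 # blank line: sentence boundary
                if sentence and sentence_tag:
                    tokens.append(sentence)
                    ner_tags.append(sentence_tag)
                    sentence, sentence_tag = [], []
    if sentence:                                  # flush a final sentence with no trailing blank line
        tokens.append(sentence)
        ner_tags.append(sentence_tag)
    return {'tokens': tokens, 'ner_tags': ner_tags}
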
Example #5
                line_split = line.split('\t')
                if len(line_split) == 2:
                    one_token, one_tag = line_split
                    # strip the trailing newline
                    one_tag = one_tag.rstrip('\n')
                    sentence.append(one_token)
                    sentence_tag.append(one_tag)
                else:
                    print(f"这一行出现问题{line},不是2个字段")
            else:
                # An empty line marks the start of the next sentence: append sentence and sentence_tag to the overall tokens and ner_tags, then reset them
                if sentence and sentence_tag:
                    tokens.append(sentence)
                    ner_tags.append(sentence_tag)
                    sentence = []
                    sentence_tag = []
    if len(tokens) != len(ner_tags):
        print('tokens and ner_tags have different lengths; the file being read is malformed, please check it')
        result = {'tokens': [], 'ner_tags': []}
    else:
        result = {'tokens': tokens, 'ner_tags': ner_tags}
    return result


if __name__ == '__main__':
    dev_file = "dataset/cosmetics/dev.txt"
    train_file = "dataset/cosmetics/train.txt"
    test_file = "dataset/cosmetics/test.txt"
    test_dict = read_ner_txt(test_file)
    dataset = Dataset.from_dict(test_dict)
    print(dataset)
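
Dataset.from_dict keeps ner_tags as plain strings here; a small follow-up sketch of mapping them to integer ids, where the label list is a placeholder for whatever tag set the files actually contain:

# Hypothetical follow-up: encode string BIO tags as integer ids.
# The label list below is a placeholder, not taken from the original files.
label_list = ['O', 'B-LOC', 'I-LOC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
label2id = {label: i for i, label in enumerate(label_list)}

def encode_tags(example):
    return {'ner_tag_ids': [label2id[tag] for tag in example['ner_tags']]}

dataset = dataset.map(encode_tags)
print(dataset[0]['ner_tag_ids'])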