import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerFast


class ChatDataset(Dataset):
    def __init__(self, filepath, tok_vocab, max_seq_len=128) -> None:
        self.filepath = filepath
        self.data = pd.read_csv(self.filepath)
        self.bos_token = '<s>'
        self.eos_token = '</s>'
        self.max_seq_len = max_seq_len
        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tok_vocab,
            bos_token=self.bos_token,
            eos_token=self.eos_token,
            unk_token='<unk>',
            pad_token='<pad>',
            mask_token='<mask>')

    def __len__(self):
        return len(self.data)

    def make_input_id_mask(self, tokens, index):
        # Convert tokens to ids and build the matching attention mask,
        # padding (or truncating) to max_seq_len.
        input_id = self.tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(input_id)
        if len(input_id) < self.max_seq_len:
            while len(input_id) < self.max_seq_len:
                input_id += [self.tokenizer.pad_token_id]
                attention_mask += [0]
        else:
            # logging.warning(f'exceed max_seq_len for given article : {index}')
            # Truncate and make sure the sequence still ends with </s>.
            input_id = input_id[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]
            attention_mask = attention_mask[:self.max_seq_len]
        return input_id, attention_mask

    def __getitem__(self, index):
        record = self.data.iloc[index]
        q, a = record['Q'], record['A']
        q_tokens = [self.bos_token] + self.tokenizer.tokenize(q) + [self.eos_token]
        a_tokens = [self.bos_token] + self.tokenizer.tokenize(a) + [self.eos_token]
        encoder_input_id, encoder_attention_mask = self.make_input_id_mask(q_tokens, index)
        decoder_input_id, decoder_attention_mask = self.make_input_id_mask(a_tokens, index)
        # Labels are the decoder tokens shifted left by one (the leading <s> is dropped).
        labels = self.tokenizer.convert_tokens_to_ids(a_tokens[1:(self.max_seq_len + 1)])
        while len(labels) < self.max_seq_len:
            labels += [-100]  # positions ignored by the cross-entropy loss
        return {'input_ids': np.array(encoder_input_id, dtype=np.int64),
                'attention_mask': np.array(encoder_attention_mask, dtype=np.float32),
                'decoder_input_ids': np.array(decoder_input_id, dtype=np.int64),
                'decoder_attention_mask': np.array(decoder_attention_mask, dtype=np.float32),
                'labels': np.array(labels, dtype=np.int64)}
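# Usage sketch (assumptions, not from the source): the paths 'ChatbotData.csv'
# and 'tokenizer.json' below are placeholders. ChatDataset only needs a CSV
# with 'Q' and 'A' columns plus a tokenizers-format vocab file; wrapping it in
# a standard PyTorch DataLoader yields (batch, max_seq_len) tensors ready for
# a seq2seq model.
from torch.utils.data import DataLoader

chat_dataset = ChatDataset('ChatbotData.csv', 'tokenizer.json', max_seq_len=128)
chat_loader = DataLoader(chat_dataset, batch_size=32, shuffle=True)
batch = next(iter(chat_loader))
print(batch['input_ids'].shape)          # torch.Size([32, 128])
print(batch['decoder_input_ids'].shape)  # torch.Size([32, 128])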
import numpy as np
from tokenizers.processors import BertProcessing
from transformers import (DataCollatorForLanguageModeling, PreTrainedTokenizerFast,
                          RobertaConfig, RobertaForMaskedLM, Trainer, TrainingArguments)


def main(args):
    data = np.load(args.data, allow_pickle=True)

    tokenizer_path = args.tokenizer
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=tokenizer_path,
        max_len=512,
        mask_token="<mask>",
        pad_token="<pad>")
    # Wrap every sequence as "<s> ... </s>", matching RoBERTa-style pretraining.
    tokenizer.backend_tokenizer.post_processor = BertProcessing(
        ("</s>", tokenizer.convert_tokens_to_ids("</s>")),
        ("<s>", tokenizer.convert_tokens_to_ids("<s>")),
    )

    config = RobertaConfig(
        vocab_size=tokenizer.vocab_size,
        max_position_embeddings=514,
        num_attention_heads=12,
        num_hidden_layers=6,
        type_vocab_size=1,
    )

    # Dynamic masking: 15% of tokens are masked on the fly in each batch.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    dataset = PhoneDatasetMLM(data, tokenizer)  # dataset class defined elsewhere in the project
    model = RobertaForMaskedLM(config=config)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=64,
        logging_steps=2,
        save_steps=10_000,
        save_total_limit=2,
        prediction_loss_only=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()
    trainer.save_model(args.output_dir)
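# Hypothetical command-line entry point (the flag names are assumptions,
# chosen only to match the attributes main() reads: args.data,
# args.tokenizer, args.output_dir).
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True,
                        help='path to the corpus file loaded with np.load')
    parser.add_argument('--tokenizer', required=True,
                        help='path to the tokenizers JSON vocab file')
    parser.add_argument('--output_dir', default='./roberta-mlm',
                        help='directory for checkpoints and the final model')
    main(parser.parse_args())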