def __init__(self, args, tokenizer, train_dataset=None, dev_dataset=None, test_dataset=None):
    """Set up a BERT sequence-classification trainer for the VLSP2020-Relex task.

    Args:
        args: Run configuration. This method reads ``args.id2label`` (passed to
            ``load_id2label``), ``args.model_name_or_path`` and ``args.no_cuda``.
        tokenizer: Tokenizer, stored on the instance for later use.
        train_dataset: Optional training split, stored as-is.
        dev_dataset: Optional development split, stored as-is.
        test_dataset: Optional test split, stored as-is.
    """
    self.args = args
    self.tokenizer = tokenizer
    self.train_dataset = train_dataset
    self.dev_dataset = dev_dataset
    self.test_dataset = test_dataset

    # Index -> label mapping; its size fixes the classifier's output dimension.
    labels_by_index = load_id2label(args.id2label)
    self.id2label = labels_by_index
    self.num_labels = len(labels_by_index)

    # The config receives id2label with stringified keys, and label2id with the
    # original (unconverted) indices as values.
    index_to_name = {str(idx): name for idx, name in labels_by_index.items()}
    name_to_index = {name: idx for idx, name in labels_by_index.items()}
    self.config = BertConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=self.num_labels,
        finetuning_task="VLSP2020-Relex",
        id2label=index_to_name,
        label2id=name_to_index,
    )
    self.model = BertForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        config=self.config,
    )

    # Run on the GPU only when one is available and args.no_cuda is not set.
    gpu_wanted = torch.cuda.is_available() and not args.no_cuda
    self.device = "cuda" if gpu_wanted else "cpu"
    self.model.to(self.device)
def __init__(self, args, tokenizer, train_dataset=None, dev_dataset=None, test_dataset=None):
    """Set up a RoBERTa relation-classification trainer for the VLSP2020-Relex task.

    Args:
        args: Run configuration. This method reads ``args.id2label`` (passed to
            ``load_id2label``), ``args.model_name_or_path``, ``args.model_type``
            and, when present, ``args.no_cuda``.
        tokenizer: Tokenizer, stored on the instance for later use.
        train_dataset: Optional training split, stored as-is.
        dev_dataset: Optional development split, stored as-is.
        test_dataset: Optional test split, stored as-is.

    Raises:
        ValueError: If ``args.model_type`` is neither ``"es"`` nor ``"all"``.
            Previously this case left ``self.model`` unset and failed later
            with an unhelpful AttributeError.
    """
    self.args = args
    self.tokenizer = tokenizer
    self.train_dataset = train_dataset
    self.dev_dataset = dev_dataset
    self.test_dataset = test_dataset
    self.id2label = load_id2label(args.id2label)
    self.num_labels = len(self.id2label)
    self.config = RobertaConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=self.num_labels,
        finetuning_task="VLSP2020-Relex",
        id2label={str(i): label for i, label in self.id2label.items()},
        label2id={label: i for i, label in self.id2label.items()},
    )

    # Select the model head: "es" -> RobertaEntityStarts, "all" -> RobertaConcatAll.
    if self.args.model_type == "es":
        self.model = RobertaEntityStarts.from_pretrained(
            args.model_name_or_path, config=self.config)
    elif self.args.model_type == "all":
        self.model = RobertaConcatAll.from_pretrained(
            args.model_name_or_path, config=self.config)
    else:
        raise ValueError(
            "Unsupported model_type {!r}; expected 'es' or 'all'".format(
                self.args.model_type))

    # GPU or CPU. Honour args.no_cuda like the BERT trainer does; getattr keeps
    # argument namespaces that lack the flag working (they default to GPU-if-available).
    use_gpu = torch.cuda.is_available() and not getattr(args, "no_cuda", False)
    self.device = "cuda" if use_gpu else "cpu"
    self.model.to(self.device)
def main(args):
    """Convert each non-blank line of ``args.input_file`` with ``convert_to_phobert``.

    Two output files are written next to the input, named from the input path
    with its extension removed (``<base>``):

    * ``<base>-phobert.txt``: the first value returned by
      ``convert_to_phobert`` for each line.
    * ``<base>-phobert-readable.txt``: the second value (the line with markers).

    Blank or whitespace-only input lines are skipped, and each kept line is
    stripped before conversion.

    Args:
        args: Reads ``args.id2label`` (passed to ``load_id2label``) and
            ``args.input_file``.
    """
    id2label = load_id2label(args.id2label)
    basename = os.path.splitext(args.input_file)[0]
    output_file = basename + "-phobert.txt"
    output_file_readable = basename + "-phobert-readable.txt"

    # Open the input first so that a missing input file does not truncate
    # existing outputs. Read it as UTF-8 explicitly, matching the outputs, so
    # decoding does not depend on the platform's default locale encoding.
    # The `with` blocks close every handle even if conversion raises midway.
    with open(args.input_file, "r", encoding="utf-8") as fin:
        with open(output_file, "w", encoding="utf-8") as fo, \
                open(output_file_readable, "w", encoding="utf-8") as fo_readable:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                phobert_line, line_with_markers = convert_to_phobert(line, id2label)
                print(phobert_line, file=fo)
                print(line_with_markers, file=fo_readable)