import os

import torch
from torch.utils.data import TensorDataset

# ABSAProcessor, convert_examples_to_seq_features and ABSADataset are
# project-local helpers assumed to be defined or imported elsewhere in this repo.


def load_and_cache_examples(args, task, tokenizer):
    # similar to that in main.py
    processor = ABSAProcessor()
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir, 'cached_{}_{}_{}_{}'.format(
            'test',
            list(filter(None, args.model_name_or_path.split('/'))).pop(),
            str(args.max_seq_length),
            str(task)))
    if os.path.exists(cached_features_file):
        print("cached_features_file:", cached_features_file)
        features = torch.load(cached_features_file)
        examples = processor.get_test_examples(args.data_dir, args.tagging_schema)
    else:
        # logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels(args.tagging_schema)
        examples = processor.get_test_examples(args.data_dir, args.tagging_schema)
        features = convert_examples_to_seq_features(
            examples=examples,
            label_list=label_list,
            tokenizer=tokenizer,
            cls_token_at_end=bool(args.model_type in ['xlnet']),
            cls_token=tokenizer.cls_token,
            sep_token=tokenizer.sep_token,
            cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
            pad_on_left=bool(args.model_type in ['xlnet']),
            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
        torch.save(features, cached_features_file)

    # keep the whitespace-tokenized words of each example for later alignment
    total_words = []
    for input_example in examples:
        text = input_example.text_a
        total_words.append(text.split(' '))

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
    # used in evaluation
    all_evaluate_label_ids = [f.evaluate_label_ids for f in features]
    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    return dataset, all_evaluate_label_ids, total_words
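
# Usage sketch (not part of the original file): the TensorDataset returned above
# can be batched with torch's SequentialSampler/DataLoader for evaluation.
# `args.eval_batch_size` is an assumed attribute, mirroring the
# pytorch-transformers example scripts.
def build_test_dataloader(args, task, tokenizer):
    from torch.utils.data import DataLoader, SequentialSampler
    dataset, evaluate_label_ids, total_words = load_and_cache_examples(args, task, tokenizer)
    # evaluation should preserve example order, hence a sequential sampler
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=args.eval_batch_size)
    return dataloader, evaluate_label_ids, total_words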
def convert_to_dataset(args, examples, tokenizer):
    processor = ABSAProcessor()
    label_list = processor.get_labels(args.tagging_schema)
    normal_labels = processor.get_normal_labels(args.tagging_schema)
    # BERT-style inputs: [CLS] first, [SEP] at the end, no left-padding;
    # this variant of the converter also returns imp_words, which is unused here
    features, imp_words = convert_examples_to_seq_features(
        examples=examples,
        label_list=(label_list, normal_labels),
        tokenizer=tokenizer,
        cls_token_at_end=False,
        cls_token=tokenizer.cls_token,
        sep_token=tokenizer.sep_token,
        cls_token_segment_id=0,
        pad_on_left=False,
        pad_token_segment_id=0)
    # keep the original position of each feature alongside the feature itself
    idxs = torch.arange(len(features))
    dataset = ABSADataset(features, idxs)
    return dataset
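
# Sketch only: ABSADataset is defined elsewhere in this project and its exact
# interface is not shown in this file. Assuming it pairs each converted feature
# with its original index and yields tensors, a minimal compatible stand-in
# could look like this (field names follow the features built in
# load_and_cache_examples above):
from torch.utils.data import Dataset


class ABSADatasetSketch(Dataset):
    """Hypothetical minimal stand-in for the project's ABSADataset."""

    def __init__(self, features, idxs):
        self.features = features
        self.idxs = idxs

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        f = self.features[index]
        return (torch.tensor(f.input_ids, dtype=torch.long),
                torch.tensor(f.input_mask, dtype=torch.long),
                torch.tensor(f.segment_ids, dtype=torch.long),
                torch.tensor(f.label_ids, dtype=torch.long),
                int(self.idxs[index]))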