def test_full_tokenizer(self):
    tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

    tokens = tokenizer.tokenize("This is a test")
    self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
    self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

    tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
    self.assertListEqual(
        tokens,
        [
            SPIECE_UNDERLINE + "I",
            SPIECE_UNDERLINE + "was",
            SPIECE_UNDERLINE + "b",
            "or",
            "n",
            SPIECE_UNDERLINE + "in",
            SPIECE_UNDERLINE + "",
            "9",
            "2",
            "0",
            "0",
            "0",
            ",",
            SPIECE_UNDERLINE + "and",
            SPIECE_UNDERLINE + "this",
            SPIECE_UNDERLINE + "is",
            SPIECE_UNDERLINE + "f",
            "al",
            "s",
            "é",
            ".",
        ],
    )
    ids = tokenizer.convert_tokens_to_ids(tokens)
    self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])

    back_tokens = tokenizer.convert_ids_to_tokens(ids)
    self.assertListEqual(
        back_tokens,
        [
            SPIECE_UNDERLINE + "I",
            SPIECE_UNDERLINE + "was",
            SPIECE_UNDERLINE + "b",
            "or",
            "n",
            SPIECE_UNDERLINE + "in",
            SPIECE_UNDERLINE + "",
            "<unk>",
            "2",
            "0",
            "0",
            "0",
            ",",
            SPIECE_UNDERLINE + "and",
            SPIECE_UNDERLINE + "this",
            SPIECE_UNDERLINE + "is",
            SPIECE_UNDERLINE + "f",
            "al",
            "s",
            "<unk>",
            ".",
        ],
    )
class XlnetProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def __init__(self, vocab_path, do_lower_case):
        self.tokenizer = XLNetTokenizer(vocab_path, do_lower_case)

    def get_train(self, data_file):
        """Gets a collection of `InputExample`s for the train set."""
        return self.read_data(data_file)

    def get_dev(self, data_file):
        """Gets a collection of `InputExample`s for the dev set."""
        return self.read_data(data_file)

    def get_test(self, lines):
        """Gets a collection of `InputExample`s for the test set."""
        return lines

    def get_labels(self):
        """Gets the list of labels for this data set."""
        return ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

    @classmethod
    def read_data(cls, input_file, quotechar=None):
        """Reads the raw data, either from a pickle file or from an already-loaded list of lines."""
        if 'pkl' in str(input_file):
            lines = load_pickle(input_file)
        else:
            lines = input_file
        return lines

    def truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        # This is a simple heuristic which will always truncate the longer sequence
        # one token at a time. This makes more sense than truncating an equal percent
        # of tokens from each, since if one sequence is very short then each token
        # that's truncated likely contains more information than a longer sequence.
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def create_examples(self, lines, example_type, cached_examples_file):
        '''
        Creates examples for the data, caching them to disk on the first pass.
        '''
        pbar = ProgressBar(n_total=len(lines))
        if cached_examples_file.exists():
            logger.info("Loading examples from cached file %s", cached_examples_file)
            examples = torch.load(cached_examples_file)
        else:
            examples = []
            for i, line in enumerate(lines):
                guid = '%s-%d' % (example_type, i)
                text_a = line[0]
                label = line[1]
                if isinstance(label, str):
                    label = [float(x) for x in label.split(",")]
                else:
                    label = [float(x) for x in list(label)]
                text_b = None
                example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)
                examples.append(example)
                pbar.batch_step(step=i, info={}, bar_type='create examples')
            logger.info("Saving examples into cached file %s", cached_examples_file)
            torch.save(examples, cached_examples_file)
        return examples

    def create_features(self, examples, max_seq_len, cached_features_file):
        '''
        The convention in BERT is:
        (a) For sequence pairs:
            tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            type_ids:   0   0    0    0    0      0    0   0   1  1  1   1  1   1
        (b) For single sequences:
            tokens:   [CLS] the dog is hairy . [SEP]
            type_ids:   0    0   0   0   0   0   0
        '''
        # Load data features from cache or build them from the examples
        pbar = ProgressBar(n_total=len(examples))
        if cached_features_file.exists():
            logger.info("Loading features from cached file %s", cached_features_file)
            features = torch.load(cached_features_file)
        else:
            features = []
            pad_token = self.tokenizer.convert_tokens_to_ids([self.tokenizer.pad_token])[0]
            cls_token = self.tokenizer.cls_token
            sep_token = self.tokenizer.sep_token
            cls_token_segment_id = 2
            pad_token_segment_id = 4
            for ex_id, example in enumerate(examples):
                tokens_a = self.tokenizer.tokenize(example.text_a)
                tokens_b = None
                label_id = example.label
                if example.text_b:
                    tokens_b = self.tokenizer.tokenize(example.text_b)
                    # Modifies `tokens_a` and `tokens_b` in place so that the total
                    # length is less than the specified length.
                    # Account for [CLS], [SEP], [SEP] with "- 3"
                    self.truncate_seq_pair(tokens_a, tokens_b, max_length=max_seq_len - 3)
                else:
                    # Account for [CLS] and [SEP] with "- 2"
                    if len(tokens_a) > max_seq_len - 2:
                        tokens_a = tokens_a[:max_seq_len - 2]

                # XLNet places its cls token at the end of the sequence
                tokens = tokens_a + [sep_token]
                segment_ids = [0] * len(tokens)
                if tokens_b:
                    tokens += tokens_b + [sep_token]
                    segment_ids += [1] * (len(tokens_b) + 1)
                tokens += [cls_token]
                segment_ids += [cls_token_segment_id]

                input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1] * len(input_ids)
                input_len = len(input_ids)
                padding_len = max_seq_len - len(input_ids)

                # Pad on the left for XLNet
                input_ids = ([pad_token] * padding_len) + input_ids
                input_mask = ([0] * padding_len) + input_mask
                segment_ids = ([pad_token_segment_id] * padding_len) + segment_ids

                assert len(input_ids) == max_seq_len
                assert len(input_mask) == max_seq_len
                assert len(segment_ids) == max_seq_len

                if ex_id < 2:
                    logger.info("*** Example ***")
                    logger.info(f"guid: {example.guid}")
                    logger.info(f"tokens: {' '.join([str(x) for x in tokens])}")
                    logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
                    logger.info(f"input_mask: {' '.join([str(x) for x in input_mask])}")
                    logger.info(f"segment_ids: {' '.join([str(x) for x in segment_ids])}")

                feature = InputFeature(input_ids=input_ids,
                                       input_mask=input_mask,
                                       segment_ids=segment_ids,
                                       label_id=label_id,
                                       input_len=input_len)
                features.append(feature)
                pbar.batch_step(step=ex_id, info={}, bar_type='create features')
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
        return features

    def create_dataset(self, features, is_sorted=False):
        # Convert to Tensors and build dataset
        if is_sorted:
            logger.info("sorted data by the length of input")
            features = sorted(features, key=lambda x: x.input_len, reverse=True)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return dataset
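
# A minimal usage sketch, not part of the original module: it shows how the
# processor methods above are meant to chain together (read_data -> create_examples
# -> create_features -> create_dataset -> DataLoader). It assumes the module's own
# helpers (load_pickle, ProgressBar, InputExample, InputFeature, logger) are
# importable as in the class above; the vocab path, pickle path, cache file names,
# max_seq_len, and batch_size below are placeholders, not values from the source.
if __name__ == "__main__":
    from pathlib import Path
    from torch.utils.data import DataLoader

    processor = XlnetProcessor(vocab_path="spiece.model", do_lower_case=False)

    # read_data() loads a pickled list of (text, labels) rows when the path contains "pkl"
    lines = processor.get_train(Path("train.pkl"))
    examples = processor.create_examples(lines,
                                         example_type="train",
                                         cached_examples_file=Path("cached_train_examples.pt"))
    features = processor.create_features(examples,
                                         max_seq_len=128,
                                         cached_features_file=Path("cached_train_features.pt"))
    dataset = processor.create_dataset(features, is_sorted=False)

    # Each batch yields (input_ids, input_mask, segment_ids, label_ids) tensors,
    # left-padded to max_seq_len as built in create_features().
    loader = DataLoader(dataset, batch_size=8, shuffle=True)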