def __init__(
        self,
        text_file,
        label_file,
        max_seq_length,
        tokenizer,
        num_samples=-1,
        pad_label='O',
        label_ids=None,
        ignore_extra_tokens=False,
        ignore_start_end=False,
        use_cache=False,
    ):

        if use_cache:
            # Cache features
            data_dir = os.path.dirname(text_file)
            filename = os.path.basename(text_file)

            if not filename.endswith('.txt'):
                raise ValueError("{text_file} should have extension .txt")

            tokenizer_type = type(tokenizer.tokenizer).__name__
            vocab_size = getattr(tokenizer, "vocab_size", 0)
            features_pkl = os.path.join(
                data_dir,
                "cached_{}_{}_{}_{}".format(filename, tokenizer_type,
                                            str(max_seq_length),
                                            str(vocab_size)),
            )
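            # e.g. "cached_text_train.txt_BertTokenizer_128_30522" (illustrative
            # filename/tokenizer/vocab values, not taken from this snippet)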
            label_ids_pkl = os.path.join(data_dir, "label_ids.pkl")

        if (use_cache and os.path.exists(features_pkl)
                and os.path.exists(label_ids_pkl)):
            # If text_file was already processed, load from pickle
            with open(features_pkl, 'rb') as f:
                features = pickle.load(f)
            logging.info(f'features restored from {features_pkl}')

            with open(label_ids_pkl, 'rb') as f:
                label_ids = pickle.load(f)
            logging.info(f'Labels to ids dict restored from {label_ids_pkl}')
        else:
            if num_samples == 0:
                raise ValueError(
                    f"num_samples has to be positive, got {num_samples}")

            with open(text_file, 'r') as f:
                text_lines = f.readlines()

            # Collect all possible labels
            unique_labels = set([])
            labels_lines = []
            with open(label_file, 'r') as f:
                for line in f:
                    line = line.strip().split()
                    labels_lines.append(line)
                    unique_labels.update(line)

            if len(labels_lines) != len(text_lines):
                raise ValueError(
                    "Labels file should contain a line of labels for every "
                    "line in the text file")

            if num_samples > 0:
                dataset = list(zip(text_lines, labels_lines))
                dataset = dataset[:num_samples]

                dataset = list(zip(*dataset))
                text_lines = dataset[0]
                labels_lines = dataset[1]

            # for dev/test sets use label mapping from training set
            if label_ids:
                if len(label_ids) != len(unique_labels):
                    logging.warning(
                        'Not all labels from the specified label_ids '
                        'dictionary are present in the current dataset. '
                        'Using the provided label_ids dictionary.')
                else:
                    logging.info('Using the provided label_ids dictionary.')
            else:
                logging.info(
                    'Creating a new label to label_id dictionary. It is '
                    'recommended to use label_ids generated during training '
                    'for dev/test sets to avoid errors if some labels are '
                    'not present in the dev/test sets. For the training '
                    'set, label_ids should be None.')

                label_ids = {pad_label: 0}
                if pad_label in unique_labels:
                    unique_labels.remove(pad_label)
                for label in sorted(unique_labels):
                    label_ids[label] = len(label_ids)
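                # e.g. pad_label='O' with unique labels {'O', 'B-PER', 'I-PER'}
                # yields label_ids == {'O': 0, 'B-PER': 1, 'I-PER': 2}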

            features = get_features(
                text_lines,
                max_seq_length,
                tokenizer,
                pad_label=pad_label,
                raw_labels=labels_lines,
                label_ids=label_ids,
                ignore_extra_tokens=ignore_extra_tokens,
                ignore_start_end=ignore_start_end,
            )

            if use_cache:
                with open(features_pkl, "wb") as f:
                    pickle.dump(features, f)
                logging.info(f'features saved to {features_pkl}')

                with open(label_ids_pkl, "wb") as f:
                    pickle.dump(label_ids, f)
                logging.info(f'labels to ids dict saved to {label_ids_pkl}')

        self.all_input_ids = features[0]
        self.all_segment_ids = features[1]
        self.all_input_mask = features[2]
        self.all_loss_mask = features[3]
        self.all_subtokens_mask = features[4]
        self.all_labels = features[5]
        self.label_ids = label_ids

        infold = os.path.dirname(text_file)
        merged_labels = itertools.chain.from_iterable(self.all_labels)
        logging.info('Three most popular labels')
        _, self.label_frequencies = get_label_stats(
            merged_labels, os.path.join(infold, 'label_stats.tsv'))

        # save label_ids
        out_file = os.path.join(infold, 'label_ids.csv')
        labels, _ = zip(*sorted(self.label_ids.items(), key=lambda x: x[1]))
        with open(out_file, 'w') as out:
            out.write('\n'.join(labels))
        logging.info(f'Labels: {self.label_ids}')
        logging.info(f'Labels mapping saved to: {out_file}')
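
# Usage sketch: the enclosing dataset class and the tokenizer construction are
# not shown in this snippet, so the names below are assumptions for illustration
# only, not the actual API.
#
#     tokenizer = ...  # any wrapper exposing .tokenizer, .vocab_size, etc.
#     dataset = TokenClassificationDataset(   # hypothetical class name
#         text_file='data/text_train.txt',
#         label_file='data/labels_train.txt',
#         max_seq_length=128,
#         tokenizer=tokenizer,
#         pad_label='O',
#         use_cache=True,
#     )
#     # reuse the training label mapping for dev/test splits
#     dev_label_ids = dataset.label_ids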
Example #2
    def __init__(self, dataset_name, data_dir, do_lower_case):
        if dataset_name == 'sst-2':
            self.data_dir = process_sst_2(data_dir)
            self.num_labels = 2
            self.eval_file = self.data_dir + '/dev.tsv'
        elif dataset_name == 'imdb':
            self.num_labels = 2
            self.data_dir = process_imdb(data_dir, do_lower_case)
            self.eval_file = self.data_dir + '/test.tsv'
        elif dataset_name == 'thucnews':
            self.num_labels = 14
            self.data_dir = process_thucnews(data_dir)
            self.eval_file = self.data_dir + '/test.tsv'
        elif dataset_name.startswith('nlu-'):
            # point data_dir at the corpus file; process_nlu returns the
            # processed directory, which becomes self.data_dir below
            if dataset_name.endswith('chat'):
                data_dir = f'{data_dir}/ChatbotCorpus.json'
                self.num_labels = 2
            elif dataset_name.endswith('ubuntu'):
                data_dir = f'{data_dir}/AskUbuntuCorpus.json'
                self.num_labels = 5
            elif dataset_name.endswith('web'):
                data_dir = f'{data_dir}/WebApplicationsCorpus.json'
                self.num_labels = 8
            self.data_dir = process_nlu(data_dir,
                                        do_lower_case,
                                        dataset_name=dataset_name)
            self.eval_file = self.data_dir + '/test.tsv'
        elif dataset_name.startswith('jarvis'):
            self.data_dir = process_jarvis_datasets(
                data_dir,
                do_lower_case,
                dataset_name,
                modes=['train', 'test', 'eval'],
                ignore_prev_intent=False)

            intents = get_intent_labels(f'{self.data_dir}/dict.intents.csv')
            self.num_labels = len(intents)
        else:
            raise ValueError("Looks like you passed a dataset name that isn't "
                             "already supported by NeMo. Please make sure "
                             "that you build the preprocessing method for it.")

        self.train_file = self.data_dir + '/train.tsv'

        for mode in ['train', 'test', 'eval']:

            if not if_exist(self.data_dir, [f'{mode}.tsv']):
                logging.info(f' Stats calculation for {mode} mode'
                             f' is skipped as {mode}.tsv was not found.')
                continue

            input_file = f'{self.data_dir}/{mode}.tsv'
            with open(input_file, 'r') as f:
                input_lines = f.readlines()[1:]  # Skipping headers at index 0

            queries, raw_sentences = [], []
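            # Each line is "<query tokens> <integer class id>": the last
            # whitespace-separated token is the sentence label; the rest is the query.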
            for input_line in input_lines:
                parts = input_line.strip().split()
                raw_sentences.append(int(parts[-1]))
                queries.append(' '.join(parts[:-1]))

            infold = input_file[:input_file.rfind('/')]

            logging.info(f'Three most popular classes in {mode} mode')
            total_sents, sent_label_freq = get_label_stats(
                raw_sentences, infold + f'/{mode}_sentence_stats.tsv')

            if mode == 'train':
                self.class_weights = calc_class_weights(sent_label_freq)
                logging.info(f'Class weights are - {self.class_weights}')

            logging.info(f'Total Sentences - {total_sents}')
            logging.info(f'Sentence class frequencies - {sent_label_freq}')
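
# Usage sketch: only __init__ is shown above, so the class name below is an
# assumption for illustration, not the actual NeMo API.
#
#     data_desc = SentenceClassificationDataDesc(   # hypothetical name
#         dataset_name='sst-2',
#         data_dir='/path/to/SST-2',
#         do_lower_case=True,
#     )
#     print(data_desc.num_labels, data_desc.train_file, data_desc.eval_file)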