Code Example #1
File: token_classification.py  Project: rssaketh/NeMo
def create_pipeline(
    pad_label=args.none_label,
    max_seq_length=args.max_seq_length,
    batch_size=args.batch_size,
    num_gpus=args.num_gpus,
    mode='train',
    batches_per_step=args.batches_per_step,
    label_ids=None,
    ignore_extra_tokens=args.ignore_extra_tokens,
    ignore_start_end=args.ignore_start_end,
    use_cache=args.use_cache,
    dropout=args.fc_dropout,
    num_layers=args.num_fc_layers,
    classifier=TokenClassifier,
):

    logging.info(f"Loading {mode} data...")
    shuffle = args.shuffle_data if mode == 'train' else False

    text_file = f'{args.data_dir}/text_{mode}.txt'
    label_file = f'{args.data_dir}/labels_{mode}.txt'

    if not (os.path.exists(text_file) and os.path.exists(label_file)):
        raise FileNotFoundError(
            f'{text_file} or {label_file} not found. The data should be split '
            'into 2 files: text.txt and labels.txt. Each line of the text.txt '
            'file contains text sequences, where words are separated with '
            'spaces. The labels.txt file contains corresponding labels for '
            'each word in text.txt, the labels are separated with spaces. '
            'Each line of the files should follow the format: '
            '[WORD] [SPACE] [WORD] [SPACE] [WORD] (for text.txt) and '
            '[LABEL] [SPACE] [LABEL] [SPACE] [LABEL] (for labels.txt).')

    data_layer = BertTokenClassificationDataLayer(
        tokenizer=tokenizer,
        text_file=text_file,
        label_file=label_file,
        pad_label=pad_label,
        label_ids=label_ids,
        max_seq_length=max_seq_length,
        batch_size=batch_size,
        shuffle=shuffle,
        ignore_extra_tokens=ignore_extra_tokens,
        ignore_start_end=ignore_start_end,
        use_cache=use_cache,
    )

    (input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask,
     labels) = data_layer()

    if mode == 'train':
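        # Training mode only: derive label ids from the dataset, optionally
        # compute class weights, and build the classifier and loss modules.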
        label_ids = data_layer.dataset.label_ids
        class_weights = None

        if args.use_weighted_loss:
            logging.info(f"Using weighted loss")
            label_freqs = data_layer.dataset.label_frequencies
            class_weights = calc_class_weights(label_freqs)

        classifier = classifier(hidden_size=hidden_size,
                                num_classes=len(label_ids),
                                dropout=dropout,
                                num_layers=num_layers)

        task_loss = CrossEntropyLossNM(logits_ndim=3, weight=class_weights)

    hidden_states = model(input_ids=input_ids,
                          token_type_ids=input_type_ids,
                          attention_mask=input_mask)
    logits = classifier(hidden_states=hidden_states)

    if mode == 'train':
        loss = task_loss(logits=logits, labels=labels, loss_mask=loss_mask)
        steps_per_epoch = len(data_layer) // (batch_size * num_gpus *
                                              batches_per_step)
        tensors_to_evaluate = [loss, logits]
        return tensors_to_evaluate, loss, steps_per_epoch, label_ids, classifier
    else:
        tensors_to_evaluate = [logits, labels, subtokens_mask]
        return tensors_to_evaluate, data_layer
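
A minimal usage sketch (not part of the original file): assuming the surrounding script has already set up args, tokenizer, model, and hidden_size, and that text_dev.txt / labels_dev.txt exist in args.data_dir, the two modes could be wired together as below, with the evaluation pipeline reusing the label ids and the already-instantiated classifier returned by the training pipeline.

# Hypothetical wiring, not taken from the original file.
train_tensors, train_loss, steps_per_epoch, label_ids, classifier = create_pipeline()
eval_tensors, eval_data_layer = create_pipeline(mode='dev',
                                                label_ids=label_ids,
                                                classifier=classifier)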
Code Example #2
    def __init__(self, dataset_name, data_dir, do_lower_case):
        if dataset_name == 'sst-2':
            self.data_dir = process_sst_2(data_dir)
            self.num_labels = 2
            self.eval_file = self.data_dir + '/dev.tsv'
        elif dataset_name == 'imdb':
            self.num_labels = 2
            self.data_dir = process_imdb(data_dir, do_lower_case)
            self.eval_file = self.data_dir + '/test.tsv'
        elif dataset_name == 'thucnews':
            self.num_labels = 14
            self.data_dir = process_thucnews(data_dir)
            self.eval_file = self.data_dir + '/test.tsv'
        elif dataset_name.startswith('nlu-'):
            if dataset_name.endswith('chat'):
                data_dir = f'{data_dir}/ChatbotCorpus.json'
                self.num_labels = 2
            elif dataset_name.endswith('ubuntu'):
                data_dir = f'{data_dir}/AskUbuntuCorpus.json'
                self.num_labels = 5
            elif dataset_name.endswith('web'):
                data_dir = f'{data_dir}/WebApplicationsCorpus.json'
                self.num_labels = 8
            self.data_dir = process_nlu(data_dir,
                                        do_lower_case,
                                        dataset_name=dataset_name)
            self.eval_file = self.data_dir + '/test.tsv'
        elif dataset_name.startswith('jarvis'):
            self.data_dir = process_jarvis_datasets(
                data_dir,
                do_lower_case,
                dataset_name,
                modes=['train', 'test', 'eval'],
                ignore_prev_intent=False)

            intents = get_intent_labels(f'{self.data_dir}/dict.intents.csv')
            self.num_labels = len(intents)
        else:
            raise ValueError("Looks like you passed a dataset name that isn't "
                             "already supported by NeMo. Please make sure "
                             "that you build the preprocessing method for it.")

        self.train_file = self.data_dir + '/train.tsv'

        for mode in ['train', 'test', 'eval']:

            if not if_exist(self.data_dir, [f'{mode}.tsv']):
                logging.info(f' Stats calculation for {mode} mode'
                             f' is skipped as {mode}.tsv was not found.')
                continue

            input_file = f'{self.data_dir}/{mode}.tsv'
            with open(input_file, 'r') as f:
                input_lines = f.readlines()[1:]  # Skipping headers at index 0

            queries, raw_sentences = [], []
            for input_line in input_lines:
                parts = input_line.strip().split()
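                # The last whitespace-separated field is the integer class
                # label; the remaining fields form the query text.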
                raw_sentences.append(int(parts[-1]))
                queries.append(' '.join(parts[:-1]))

            infold = input_file[:input_file.rfind('/')]

            logging.info(f'Three most popular classes during {mode}ing')
            total_sents, sent_label_freq = get_label_stats(
                raw_sentences, infold + f'/{mode}_sentence_stats.tsv')

            if mode == 'train':
                self.class_weights = calc_class_weights(sent_label_freq)
                logging.info(f'Class weights are - {self.class_weights}')

            logging.info(f'Total Sentences - {total_sents}')
            logging.info(f'Sentence class frequencies - {sent_label_freq}')
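
A short usage sketch: the snippet only shows __init__, so the class name (SentenceClassificationDataDesc) and the data path below are assumptions for illustration.

# Hypothetical usage; the class name and path are not shown in the snippet above.
data_desc = SentenceClassificationDataDesc(dataset_name='sst-2',
                                           data_dir='/path/to/SST-2',
                                           do_lower_case=True)
print(data_desc.num_labels)  # 2 for sst-2
print(data_desc.train_file)  # <processed data dir>/train.tsv
print(data_desc.eval_file)   # <processed data dir>/dev.tsv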