Code Example #1
File: tagger_dataset.py  Project: manneh/NeMo
    def __init__(self, w_words, s_words, direction, lang, do_basic_tokenize=False):
        self.lang = lang  # language code consumed by basic_tokenize below
        # Build input_words and labels
        input_words, labels = [], []
        # Task Prefix
        if direction == constants.INST_BACKWARD:
            input_words.append(constants.ITN_PREFIX)
        if direction == constants.INST_FORWARD:
            input_words.append(constants.TN_PREFIX)
        labels.append(constants.TASK_TAG)
        # Main Content
        for w_word, s_word in zip(w_words, s_words):
            # Basic tokenization (if enabled)
            if do_basic_tokenize:
                w_word = ' '.join(basic_tokenize(w_word, self.lang))
                if not s_word in constants.SPECIAL_WORDS:
                    s_word = ' '.join(basic_tokenize(s_word, self.lang))
            # Update input_words and labels
            if s_word == constants.SIL_WORD and direction == constants.INST_BACKWARD:
                continue

            if s_word in constants.SPECIAL_WORDS:
                input_words.append(w_word)
                labels.append(constants.SAME_TAG)
            else:
                if direction == constants.INST_BACKWARD:
                    input_words.append(s_word)
                if direction == constants.INST_FORWARD:
                    input_words.append(w_word)
                labels.append(constants.TRANSFORM_TAG)
        self.input_words = input_words
        self.labels = labels
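To make the tagging scheme concrete: below is a minimal, self-contained sketch of the same alignment logic for the forward (TN) direction only. The constant values are illustrative stand-ins, not NeMo's actual definitions.

# Stand-in values; NeMo defines the real ones in its constants module.
TN_PREFIX, TASK_TAG = '<TN>', 'TASK'
SAME_TAG, TRANSFORM_TAG = 'SAME', 'TRANSFORM'
SPECIAL_WORDS = {'<self>', 'sil'}

def build_tagger_example(w_words, s_words):
    # Forward (TN) direction: the input is the written form, one tag per word.
    input_words, labels = [TN_PREFIX], [TASK_TAG]
    for w_word, s_word in zip(w_words, s_words):
        input_words.append(w_word)
        # Words whose spoken form is a special marker are kept as-is (SAME).
        labels.append(SAME_TAG if s_word in SPECIAL_WORDS else TRANSFORM_TAG)
    return input_words, labels

print(build_tagger_example(['on', 'Dec.', '3'], ['<self>', 'december', 'third']))
# (['<TN>', 'on', 'Dec.', '3'], ['TASK', 'SAME', 'TRANSFORM', 'TRANSFORM'])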
Code Example #2
    def __init__(
        self, w_words, s_words, inst_dir, start_idx, end_idx, lang, semiotic_class=None, do_basic_tokenize=False
    ):
        start_idx = max(start_idx, 0)
        end_idx = min(end_idx, len(w_words))
        ctx_size = constants.DECODE_CTX_SIZE
        extra_id_0 = constants.EXTRA_ID_0
        extra_id_1 = constants.EXTRA_ID_1

        # Extract center words
        c_w_words = w_words[start_idx:end_idx]
        c_s_words = s_words[start_idx:end_idx]

        # Extract context
        w_left = w_words[max(0, start_idx - ctx_size) : start_idx]
        w_right = w_words[end_idx : end_idx + ctx_size]
        s_left = s_words[max(0, start_idx - ctx_size) : start_idx]
        s_right = s_words[end_idx : end_idx + ctx_size]

        # Process sil words and self words
        for jx in range(len(s_left)):
            if s_left[jx] == constants.SIL_WORD:
                s_left[jx] = ''
            if s_left[jx] == constants.SELF_WORD:
                s_left[jx] = w_left[jx]
        for jx in range(len(s_right)):
            if s_right[jx] == constants.SIL_WORD:
                s_right[jx] = ''
            if s_right[jx] == constants.SELF_WORD:
                s_right[jx] = w_right[jx]
        for jx in range(len(c_s_words)):
            if c_s_words[jx] == constants.SIL_WORD:
                c_s_words[jx] = ''
                if inst_dir == constants.INST_BACKWARD:
                    c_w_words[jx] = ''
            if c_s_words[jx] == constants.SELF_WORD:
                c_s_words[jx] = c_w_words[jx]

        # Extract input_words and output_words
        if do_basic_tokenize:
            c_w_words = basic_tokenize(' '.join(c_w_words), lang)
            c_s_words = basic_tokenize(' '.join(c_s_words), lang)
        w_input = w_left + [extra_id_0] + c_w_words + [extra_id_1] + w_right
        s_input = s_left + [extra_id_0] + c_s_words + [extra_id_1] + s_right
        if inst_dir == constants.INST_BACKWARD:
            input_center_words = c_s_words
            input_words = [constants.ITN_PREFIX] + s_input
            output_words = c_w_words
        if inst_dir == constants.INST_FORWARD:
            input_center_words = c_w_words
            input_words = [constants.TN_PREFIX] + w_input
            output_words = c_s_words
        # Finalize
        self.input_str = ' '.join(input_words)
        self.input_center_str = ' '.join(input_center_words)
        self.output_str = ' '.join(output_words)
        self.direction = inst_dir
        self.semiotic_class = semiotic_class
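A self-contained sketch of the context windowing above: the span to be decoded is wrapped in the sentinel tokens and padded with a fixed number of context words on each side. The window size of 3 and the token strings are assumptions here; the real code uses constants.DECODE_CTX_SIZE, EXTRA_ID_0 and EXTRA_ID_1.

EXTRA_ID_0, EXTRA_ID_1, TN_PREFIX = '<extra_id_0>', '<extra_id_1>', '<TN>'  # stand-ins

def build_decoder_input(w_words, start_idx, end_idx, ctx_size=3):
    # Keep at most ctx_size words of context on each side of the target span.
    left = w_words[max(0, start_idx - ctx_size):start_idx]
    center = w_words[start_idx:end_idx]
    right = w_words[end_idx:end_idx + ctx_size]
    return ' '.join([TN_PREFIX] + left + [EXTRA_ID_0] + center + [EXTRA_ID_1] + right)

print(build_decoder_input('the meeting is on Dec. 3 at noon'.split(), 4, 6))
# <TN> meeting is on <extra_id_0> Dec. 3 <extra_id_1> at noon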
Code Example #3
    def input_preprocessing(self, sents):
        """ Function for preprocessing the input texts. The function first does
        some basic tokenization. For English, it then also processes Greek letters
        such as Δ or λ (if any).

        Args:
            sents: A list of input texts.

        Returns: A list of preprocessed input texts.
        """
        # Basic Preprocessing and Tokenization
        if self.lang == constants.ENGLISH:
            for ix, sent in enumerate(sents):
                sents[ix] = sents[ix].replace('+', ' plus ')
                sents[ix] = sents[ix].replace('=', ' equals ')
                sents[ix] = sents[ix].replace('@', ' at ')
                sents[ix] = sents[ix].replace('*', ' times ')
        sents = [basic_tokenize(sent, self.lang) for sent in sents]

        # Greek letters processing
        if self.lang == constants.ENGLISH:
            for ix, sent in enumerate(sents):
                for jx, tok in enumerate(sent):
                    if tok in constants.EN_GREEK_TO_SPOKEN:
                        sents[ix][jx] = constants.EN_GREEK_TO_SPOKEN[tok]

        return sents
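A stand-alone sketch of the English-only substitutions done above. The Greek-letter map is an illustrative subset, and a whitespace split stands in for basic_tokenize.

EN_GREEK_TO_SPOKEN = {'Δ': 'delta', 'λ': 'lambda'}  # illustrative subset

def preprocess_en(sent):
    for sym, spoken in [('+', ' plus '), ('=', ' equals '), ('@', ' at '), ('*', ' times ')]:
        sent = sent.replace(sym, spoken)
    return [EN_GREEK_TO_SPOKEN.get(tok, tok) for tok in sent.split()]

print(preprocess_en('Δ = 5 + 3'))  # ['delta', 'equals', '5', 'plus', '3']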
Code Example #4
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang = cfg.lang
    tagger_trainer, tagger_model = instantiate_model_and_trainer(
        cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(
        cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)

    if not cfg.inference.interactive:
        # Setup test_dataset
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path,
                                                    cfg.data.test_ds.mode,
                                                    lang)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size,
                                    cfg.inference.errors_log_fp)
        print(f'\nTest results: {results}')
    else:
        while True:
            test_input = input('Input a test input:')
            test_input = ' '.join(basic_tokenize(test_input, lang))
            outputs = tn_model._infer(
                [test_input, test_input],
                [constants.INST_BACKWARD, constants.INST_FORWARD])[-1]
            print(f'Prediction (ITN): {outputs[0]}')
            print(f'Prediction (TN): {outputs[1]}')

            should_continue = input('\nContinue (y/n): ').strip().lower()
            if should_continue.startswith('n'):
                break
Code Example #5
    def __init__(self,
                 input_file: str,
                 mode: str,
                 lang: str,
                 keep_puncts: bool = False):
        self.lang = lang
        insts = read_data_file(input_file)

        # Build inputs and targets
        self.directions, self.inputs, self.targets = [], [], []
        for (_, w_words, s_words) in insts:
            # Extract words that are not punctuation
            processed_w_words, processed_s_words = [], []
            for w_word, s_word in zip(w_words, s_words):
                if s_word == constants.SIL_WORD:
                    if keep_puncts:
                        processed_w_words.append(w_word)
                        processed_s_words.append(w_word)
                    continue
                if s_word == constants.SELF_WORD:
                    processed_s_words.append(w_word)
                if not s_word in constants.SPECIAL_WORDS:
                    processed_s_words.append(s_word)
                processed_w_words.append(w_word)
            # Create examples
            for direction in constants.INST_DIRECTIONS:
                if direction == constants.INST_BACKWARD:
                    if mode == constants.TN_MODE:
                        continue
                    input_words = processed_s_words
                    output_words = processed_w_words
                if direction == constants.INST_FORWARD:
                    if mode == constants.ITN_MODE:
                        continue
                    input_words = w_words
                    output_words = processed_s_words
                # Basic tokenization
                input_words = basic_tokenize(' '.join(input_words), lang)
                output_words = basic_tokenize(' '.join(output_words), lang)
                # Update self.directions, self.inputs, self.targets
                self.directions.append(direction)
                self.inputs.append(' '.join(input_words))
                self.targets.append(' '.join(output_words))
        self.examples = list(zip(self.directions, self.inputs, self.targets))
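The sil/self handling above can be isolated into a small self-contained sketch; the marker strings are stand-ins for constants.SIL_WORD and constants.SELF_WORD.

SIL_WORD, SELF_WORD = 'sil', '<self>'  # stand-ins

def build_spoken_side(w_words, s_words, keep_puncts=False):
    # Silence markers are dropped (or replaced by the written token if keep_puncts),
    # self markers copy the written token, everything else keeps the spoken form.
    out = []
    for w, s in zip(w_words, s_words):
        if s == SIL_WORD:
            if keep_puncts:
                out.append(w)
            continue
        out.append(w if s == SELF_WORD else s)
    return ' '.join(out)

print(build_spoken_side(['Hello', ',', 'Dec.', '3'], ['<self>', 'sil', 'december', 'third']))
# Hello december third
print(build_spoken_side(['Hello', ',', 'Dec.', '3'], ['<self>', 'sil', 'december', 'third'], True))
# Hello , december third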
Code Example #6
File: duplex_tn.py  Project: stjordanis/NeMo
    def _infer(self, sents: List[str], inst_directions: List[str]):
        """ Main function for Inference
        Args:
            sents: A list of input texts.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).

        Returns:
            tag_preds: A list of lists where each list contains the tag predictions from the tagger for an input text.
            output_spans: A list of lists where each list contains the decoded semiotic spans from the decoder for an input text.
            final_outputs: A list of str where each str is the final output text for an input text.
        """
        # Preprocessing
        sents = self.input_preprocessing(list(sents))

        # Tagging
        tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(
            sents, inst_directions)
        output_spans = self.decoder._infer(sents, nb_spans, span_starts,
                                           span_ends, inst_directions)

        # Prepare final outputs
        final_outputs = []
        for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
            cur_words, jx, span_idx = [], 0, 0
            cur_spans = output_spans[ix]
            while jx < len(sent):
                tag, word = tags[jx], sent[jx]
                if constants.SAME_TAG in tag:
                    cur_words.append(word)
                    jx += 1
                elif constants.PUNCT_TAG in tag:
                    jx += 1
                else:
                    jx += 1
                    cur_words.append(cur_spans[span_idx])
                    span_idx += 1
                    while jx < len(sent) and tags[
                            jx] == constants.I_PREFIX + constants.TRANSFORM_TAG:
                        jx += 1
            cur_output_str = ' '.join(cur_words)
            cur_output_str = ' '.join(basic_tokenize(cur_output_str,
                                                     self.lang))
            final_outputs.append(cur_output_str)
        return tag_preds, output_spans, final_outputs
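The merge step at the end of _infer can also be shown in isolation: SAME tokens are copied, the first token of each TRANSFORM span is replaced by the next decoded span, following I- tokens are skipped, and PUNCT tokens are dropped. The tag strings below are stand-ins for NeMo's constants.

SAME, B_TRANS, I_TRANS, PUNCT = 'SAME', 'B-TRANSFORM', 'I-TRANSFORM', 'PUNCT'  # stand-ins

def merge(sent_tokens, tags, decoded_spans):
    out, jx, span_idx = [], 0, 0
    while jx < len(sent_tokens):
        tag = tags[jx]
        if SAME in tag:            # copy the token unchanged
            out.append(sent_tokens[jx])
            jx += 1
        elif PUNCT in tag:         # punctuation is dropped
            jx += 1
        else:                      # start of a transform span
            out.append(decoded_spans[span_idx])
            span_idx += 1
            jx += 1
            while jx < len(sent_tokens) and tags[jx] == I_TRANS:
                jx += 1
    return ' '.join(out)

print(merge(['on', 'Dec.', '3'], [SAME, B_TRANS, I_TRANS], ['december third']))
# on december third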
Code Example #7
File: test_dataset.py  Project: carolmanderson/NeMo
    def __init__(self, input_file: str, mode: str, lang: str):
        self.lang = lang
        insts = read_data_file(input_file, lang=lang)

        # Build inputs and targets
        self.directions, self.inputs, self.targets, self.classes, self.nb_spans, self.span_starts, self.span_ends = (
            [],
            [],
            [],
            [],
            [],
            [],
            [],
        )
        for (classes, w_words, s_words) in insts:
            # Extract words that are not punctuation
            for direction in constants.INST_DIRECTIONS:
                if direction == constants.INST_BACKWARD:
                    if mode == constants.TN_MODE:
                        continue

                    # ITN mode
                    (
                        processed_w_words,
                        processed_s_words,
                        processed_classes,
                        processed_nb_spans,
                        processed_s_span_starts,
                        processed_s_span_ends,
                    ) = ([], [], [], 0, [], [])
                    s_word_idx = 0
                    for cls, w_word, s_word in zip(classes, w_words, s_words):
                        if s_word == constants.SIL_WORD:
                            continue
                        elif s_word == constants.SELF_WORD:
                            processed_s_words.append(w_word)
                        else:
                            processed_s_words.append(s_word)

                        processed_nb_spans += 1
                        processed_classes.append(cls)
                        processed_s_span_starts.append(s_word_idx)
                        s_word_idx += len(
                            basic_tokenize(processed_s_words[-1],
                                           lang=self.lang))
                        processed_s_span_ends.append(s_word_idx)
                        processed_w_words.append(w_word)

                    self.span_starts.append(processed_s_span_starts)
                    self.span_ends.append(processed_s_span_ends)
                    self.classes.append(processed_classes)
                    self.nb_spans.append(processed_nb_spans)
                    # Basic tokenization
                    input_words = basic_tokenize(' '.join(processed_s_words),
                                                 lang)
                    # Update self.directions, self.inputs, self.targets
                    self.directions.append(direction)
                    self.inputs.append(' '.join(input_words))
                    self.targets.append(
                        processed_w_words
                    )  # is list of lists where inner list contains target tokens (not words)

                # TN mode
                elif direction == constants.INST_FORWARD:
                    if mode == constants.ITN_MODE:
                        continue
                    (
                        processed_w_words,
                        processed_s_words,
                        processed_classes,
                        processed_nb_spans,
                        w_span_starts,
                        w_span_ends,
                    ) = ([], [], [], 0, [], [])
                    w_word_idx = 0
                    for cls, w_word, s_word in zip(classes, w_words, s_words):

                        # TN forward mode
                        if s_word in constants.SPECIAL_WORDS:
                            processed_s_words.append(w_word)
                        else:
                            processed_s_words.append(s_word)

                        w_span_starts.append(w_word_idx)
                        w_word_idx += len(
                            basic_tokenize(w_word, lang=self.lang))
                        w_span_ends.append(w_word_idx)
                        processed_nb_spans += 1
                        processed_classes.append(cls)
                        processed_w_words.append(w_word)

                    self.span_starts.append(w_span_starts)
                    self.span_ends.append(w_span_ends)
                    self.classes.append(processed_classes)
                    self.nb_spans.append(processed_nb_spans)
                    # Basic tokenization
                    input_words = basic_tokenize(' '.join(processed_w_words),
                                                 lang)
                    # Update self.directions, self.inputs, self.targets
                    self.directions.append(direction)
                    self.inputs.append(' '.join(input_words))
                    self.targets.append(
                        processed_s_words
                    )  # is list of lists where inner list contains target tokens (not words)

        self.examples = list(
            zip(
                self.directions,
                self.inputs,
                self.targets,
                self.classes,
                self.nb_spans,
                self.span_starts,
                self.span_ends,
            ))
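A hedged usage sketch mirroring code example #4; the file path and the mode/language values are placeholders.

test_dataset = TextNormalizationTestDataset('test.tsv', mode=constants.TN_MODE,
                                            lang=constants.ENGLISH)
# Each example bundles direction, input string, targets, classes and span metadata.
direction, inp, targets, classes, nb_spans, span_starts, span_ends = test_dataset.examples[0]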
Code Example #8
    def evaluate(self,
                 dataset: TextNormalizationTestDataset,
                 batch_size: int,
                 errors_log_fp: str,
                 verbose: bool = True):
        """ Function for evaluating the performance of the model on a dataset

        Args:
            dataset: The dataset to be used for evaluation.
            batch_size: Batch size to use during inference. You can set it to be 1
                (no batching) if you want to measure the running time of the model
                per individual example (assuming requests are coming to the model one-by-one).
            errors_log_fp: Path to the file for logging the errors
            verbose: if True, prints and logs various evaluation results.

        Returns:
            results: A Dict containing the evaluation results (e.g., accuracy, running time)
        """
        results = {}
        error_f = open(errors_log_fp, 'w+')

        # Apply the model on the dataset
        (
            all_run_times,
            all_dirs,
            all_inputs,
            all_targets,
            all_classes,
            all_nb_spans,
            all_span_starts,
            all_span_ends,
            all_output_spans,
        ) = ([], [], [], [], [], [], [], [], [])
        all_tag_preds, all_final_preds = [], []
        nb_iters = int(ceil(len(dataset) / batch_size))
        for i in tqdm(range(nb_iters)):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size
            batch_insts = dataset[start_idx:end_idx]
            (
                batch_dirs,
                batch_inputs,
                batch_targets,
                batch_classes,
                batch_nb_spans,
                batch_span_starts,
                batch_span_ends,
            ) = zip(*batch_insts)
            # Inference and Running Time Measurement
            batch_start_time = perf_counter()
            batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(
                batch_inputs, batch_dirs)
            batch_run_time = (perf_counter() -
                              batch_start_time) * 1000  # milliseconds
            all_run_times.append(batch_run_time)
            # Update all_dirs, all_inputs, all_tag_preds, all_final_preds and all_targets
            all_dirs.extend(batch_dirs)
            all_inputs.extend(batch_inputs)
            all_tag_preds.extend(batch_tag_preds)
            all_final_preds.extend(batch_final_preds)
            all_targets.extend(batch_targets)
            all_classes.extend(batch_classes)
            all_nb_spans.extend(batch_nb_spans)
            all_span_starts.extend(batch_span_starts)
            all_span_ends.extend(batch_span_ends)
            all_output_spans.extend(batch_output_spans)

        # Metrics
        tn_error_ctx, itn_error_ctx = 0, 0
        for direction in constants.INST_DIRECTIONS:
            (
                cur_dirs,
                cur_inputs,
                cur_tag_preds,
                cur_final_preds,
                cur_targets,
                cur_classes,
                cur_nb_spans,
                cur_span_starts,
                cur_span_ends,
                cur_output_spans,
            ) = ([], [], [], [], [], [], [], [], [], [])
            for dir, _input, tag_pred, final_pred, target, cls, nb_spans, span_starts, span_ends, output_spans in zip(
                    all_dirs,
                    all_inputs,
                    all_tag_preds,
                    all_final_preds,
                    all_targets,
                    all_classes,
                    all_nb_spans,
                    all_span_starts,
                    all_span_ends,
                    all_output_spans,
            ):
                if dir == direction:
                    cur_dirs.append(dir)
                    cur_inputs.append(_input)
                    cur_tag_preds.append(tag_pred)
                    cur_final_preds.append(final_pred)
                    cur_targets.append(target)
                    cur_classes.append(cls)
                    cur_nb_spans.append(nb_spans)
                    cur_span_starts.append(span_starts)
                    cur_span_ends.append(span_ends)
                    cur_output_spans.append(output_spans)
            nb_instances = len(cur_final_preds)
            cur_targets_sent = [" ".join(x) for x in cur_targets]
            sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
                cur_final_preds, cur_targets_sent, cur_dirs, self.lang)
            class_accuracy = TextNormalizationTestDataset.compute_class_accuracy(
                self.input_preprocessing(list(cur_inputs)),
                cur_targets,
                cur_tag_preds,
                cur_dirs,
                cur_output_spans,
                cur_classes,
                cur_nb_spans,
                cur_span_starts,
                cur_span_ends,
                self.lang,
            )
            # Format per-class accuracies up front so `results` is valid even when verbose=False
            if not isinstance(class_accuracy, str):
                log_class_accuracies = ""
                for key, value in class_accuracy.items():
                    log_class_accuracies += f"\n\t{key}:\t{value[0]}\t{value[1]}/{value[2]}"
            else:
                log_class_accuracies = class_accuracy
            if verbose:
                logging.info(
                    f'\n============ Direction {direction} ============')
                logging.info(f'Sentence Accuracy: {sent_accuracy}')
                logging.info(f'nb_instances: {nb_instances}')
                logging.info(f'class accuracies: {log_class_accuracies}')
            # Update results
            results[direction] = {
                'sent_accuracy': sent_accuracy,
                'nb_instances': nb_instances,
                'class_accuracy': log_class_accuracies,
            }
            # Write errors to log file
            for _input, tag_pred, final_pred, target in zip(
                    cur_inputs, cur_tag_preds, cur_final_preds,
                    cur_targets_sent):
                if not TextNormalizationTestDataset.is_same(
                        final_pred, target, direction, self.lang):
                    if direction == constants.INST_BACKWARD:
                        error_f.write('Backward Problem (ITN)\n')
                        itn_error_ctx += 1
                    elif direction == constants.INST_FORWARD:
                        error_f.write('Forward Problem (TN)\n')
                        tn_error_ctx += 1
                    formatted_input_str = get_formatted_string(
                        basic_tokenize(_input, lang=self.lang))
                    formatted_tag_pred_str = get_formatted_string(tag_pred)
                    error_f.write(f'Original Input : {_input}\n')
                    error_f.write(f'Input          : {formatted_input_str}\n')
                    error_f.write(
                        f'Predicted Tags : {formatted_tag_pred_str}\n')
                    error_f.write(f'Predicted Str  : {final_pred}\n')
                    error_f.write(f'Ground-Truth   : {target}\n')
                    error_f.write('\n')
            results['itn_error_ctx'] = itn_error_ctx
            results['tn_error_ctx'] = tn_error_ctx

        # Running Time
        avg_running_time = np.average(all_run_times) / batch_size  # in ms
        if verbose:
            logging.info(
                f'Average running time (normalized by batch size): {avg_running_time} ms'
            )
        results['running_time'] = avg_running_time

        # Close log file
        error_f.close()

        return results
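A hedged usage sketch of evaluate, mirroring the call in code example #4; the batch size and log path are placeholders.

results = tn_model.evaluate(test_dataset, batch_size=32, errors_log_fp='errors.log')
print(results[constants.INST_FORWARD]['sent_accuracy'])  # per-direction sentence accuracy
print(results['running_time'])                           # average ms, normalized by batch size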
Code Example #9
    parser.add_argument('--lang',
                        type=str,
                        default=constants.ENGLISH,
                        choices=constants.SUPPORTED_LANGS,
                        help='Language')
    args = parser.parse_args()

    # Create the output dir (if not exist)
    if not isdir(args.output_dir):
        mkdir(args.output_dir)

    # Processing
    train, dev, test = read_google_data(args.data_dir, args.lang)
    for split, data in zip(constants.SPLIT_NAMES, [train, dev, test]):
        output_f = open(join(args.output_dir, f'{split}.tsv'),
                        'w+',
                        encoding='utf-8')
        for inst in data:
            cur_classes, cur_tokens, cur_outputs = inst
            for c, t, o in zip(cur_classes, cur_tokens, cur_outputs):
                t = ' '.join(basic_tokenize(t, args.lang))
                if not o in constants.SPECIAL_WORDS:
                    o_tokens = basic_tokenize(o, args.lang)
                    o_tokens = [
                        o_tok for o_tok in o_tokens
                        if o_tok != constants.SIL_WORD
                    ]
                    o = ' '.join(o_tokens)
                output_f.write(f'{c}\t{t}\t{o}\n')
            output_f.write('<eos>\t<eos>\n')
        output_f.close()
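Each sentence thus becomes a block of class/written/spoken lines followed by an <eos> marker, so the resulting tab-separated file looks roughly like this (the rows are illustrative, not real data):

PLAIN    the      <self>
DATE     2012     twenty twelve
PUNCT    .        sil
<eos>    <eos>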