def __init__(self, w_words, s_words, direction, lang, do_basic_tokenize=False):
    # Language of the instance (used by basic_tokenize)
    self.lang = lang
    # Build input_words and labels
    input_words, labels = [], []
    # Task Prefix
    if direction == constants.INST_BACKWARD:
        input_words.append(constants.ITN_PREFIX)
    if direction == constants.INST_FORWARD:
        input_words.append(constants.TN_PREFIX)
    labels.append(constants.TASK_TAG)
    # Main Content
    for w_word, s_word in zip(w_words, s_words):
        # Basic tokenization (if enabled)
        if do_basic_tokenize:
            w_word = ' '.join(basic_tokenize(w_word, self.lang))
            if s_word not in constants.SPECIAL_WORDS:
                s_word = ' '.join(basic_tokenize(s_word, self.lang))
        # Update input_words and labels
        if s_word == constants.SIL_WORD and direction == constants.INST_BACKWARD:
            continue
        if s_word in constants.SPECIAL_WORDS:
            input_words.append(w_word)
            labels.append(constants.SAME_TAG)
        else:
            if direction == constants.INST_BACKWARD:
                input_words.append(s_word)
            if direction == constants.INST_FORWARD:
                input_words.append(w_word)
            labels.append(constants.TRANSFORM_TAG)
    self.input_words = input_words
    self.labels = labels
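# Illustrative sketch (not part of the original file): how a tagger instance is laid
# out for the forward (TN) direction. The literal marker/tag strings below ('tn',
# '<self>', 'TASK', 'SAME', 'TRANSFORM') are assumptions standing in for the values
# defined in constants.py; only the structure (task prefix first, one tag per input
# token) mirrors the code above.
toy_w_words = ['on', 'may', '5', '1998']
toy_s_words = ['<self>', '<self>', 'fifth', 'nineteen ninety eight']
toy_input_words = ['tn'] + toy_w_words             # task prefix + written-form words
toy_labels = ['TASK', 'SAME', 'SAME', 'TRANSFORM', 'TRANSFORM']
assert len(toy_input_words) == len(toy_labels)     # one tag per input token
print(list(zip(toy_input_words, toy_labels)))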
def __init__(
    self, w_words, s_words, inst_dir, start_idx, end_idx, lang, semiotic_class=None, do_basic_tokenize=False
):
    start_idx = max(start_idx, 0)
    end_idx = min(end_idx, len(w_words))
    ctx_size = constants.DECODE_CTX_SIZE
    extra_id_0 = constants.EXTRA_ID_0
    extra_id_1 = constants.EXTRA_ID_1

    # Extract center words
    c_w_words = w_words[start_idx:end_idx]
    c_s_words = s_words[start_idx:end_idx]

    # Extract context
    w_left = w_words[max(0, start_idx - ctx_size) : start_idx]
    w_right = w_words[end_idx : end_idx + ctx_size]
    s_left = s_words[max(0, start_idx - ctx_size) : start_idx]
    s_right = s_words[end_idx : end_idx + ctx_size]

    # Process sil words and self words
    for jx in range(len(s_left)):
        if s_left[jx] == constants.SIL_WORD:
            s_left[jx] = ''
        if s_left[jx] == constants.SELF_WORD:
            s_left[jx] = w_left[jx]
    for jx in range(len(s_right)):
        if s_right[jx] == constants.SIL_WORD:
            s_right[jx] = ''
        if s_right[jx] == constants.SELF_WORD:
            s_right[jx] = w_right[jx]
    for jx in range(len(c_s_words)):
        if c_s_words[jx] == constants.SIL_WORD:
            c_s_words[jx] = ''
            if inst_dir == constants.INST_BACKWARD:
                c_w_words[jx] = ''
        if c_s_words[jx] == constants.SELF_WORD:
            c_s_words[jx] = c_w_words[jx]

    # Extract input_words and output_words
    if do_basic_tokenize:
        c_w_words = basic_tokenize(' '.join(c_w_words), lang)
        c_s_words = basic_tokenize(' '.join(c_s_words), lang)
    w_input = w_left + [extra_id_0] + c_w_words + [extra_id_1] + w_right
    s_input = s_left + [extra_id_0] + c_s_words + [extra_id_1] + s_right
    if inst_dir == constants.INST_BACKWARD:
        input_center_words = c_s_words
        input_words = [constants.ITN_PREFIX] + s_input
        output_words = c_w_words
    if inst_dir == constants.INST_FORWARD:
        input_center_words = c_w_words
        input_words = [constants.TN_PREFIX] + w_input
        output_words = c_s_words

    # Finalize
    self.input_str = ' '.join(input_words)
    self.input_center_str = ' '.join(input_center_words)
    self.output_str = ' '.join(output_words)
    self.direction = inst_dir
    self.semiotic_class = semiotic_class
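# Illustrative sketch (not part of the original file): the final decoder input is
#   <task prefix> [left context] <extra_id_0> [center span] <extra_id_1> [right context]
# The literal marker strings below ('tn', '<extra_id_0>', '<extra_id_1>') are
# assumptions for this toy example; the real values come from constants.py.
toy_left = ['i', 'was', 'born', 'on']
toy_center = ['may', '5', '1998']
toy_right = ['in', 'hanoi']
toy_input_str = ' '.join(['tn'] + toy_left + ['<extra_id_0>'] + toy_center + ['<extra_id_1>'] + toy_right)
print(toy_input_str)  # tn i was born on <extra_id_0> may 5 1998 <extra_id_1> in hanoi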
def input_preprocessing(self, sents):
    """
    Function for preprocessing the input texts. The function first does some basic
    tokenization. For English, it also replaces common symbols (e.g., '+', '=') and
    Greek letters such as Δ or λ with their spoken forms (if any).

    Args:
        sents: A list of input texts.

    Returns: A list of token lists, one per preprocessed input text.
    """
    # Basic Preprocessing and Tokenization
    if self.lang == constants.ENGLISH:
        for ix, sent in enumerate(sents):
            sents[ix] = sents[ix].replace('+', ' plus ')
            sents[ix] = sents[ix].replace('=', ' equals ')
            sents[ix] = sents[ix].replace('@', ' at ')
            sents[ix] = sents[ix].replace('*', ' times ')
    sents = [basic_tokenize(sent, self.lang) for sent in sents]

    # Greek letters processing
    if self.lang == constants.ENGLISH:
        for ix, sent in enumerate(sents):
            for jx, tok in enumerate(sent):
                if tok in constants.EN_GREEK_TO_SPOKEN:
                    sents[ix][jx] = constants.EN_GREEK_TO_SPOKEN[tok]

    return sents
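# Illustrative sketch (not part of the original file): the effect of the English
# preprocessing on a toy sentence. Whitespace splitting and the small Greek mapping
# below stand in for basic_tokenize and constants.EN_GREEK_TO_SPOKEN; both are
# assumptions made only so this snippet runs on its own.
toy_greek_to_spoken = {'Δ': 'delta', 'λ': 'lambda'}
toy_sent = 'the Δ is 2 + 3'
for sym, spoken in [('+', ' plus '), ('=', ' equals '), ('@', ' at '), ('*', ' times ')]:
    toy_sent = toy_sent.replace(sym, spoken)
toy_tokens = [toy_greek_to_spoken.get(tok, tok) for tok in toy_sent.split()]
print(toy_tokens)  # ['the', 'delta', 'is', '2', 'plus', '3']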
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang = cfg.lang

    # Load the trained tagger and decoder and combine them into a duplex model
    tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)

    if not cfg.inference.interactive:
        # Setup test_dataset
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.data.test_ds.mode, lang)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.inference.errors_log_fp)
        print(f'\nTest results: {results}')
    else:
        # Interactive mode: run both directions (ITN and TN) on each typed input
        while True:
            test_input = input('Input a test input: ')
            test_input = ' '.join(basic_tokenize(test_input, lang))
            outputs = tn_model._infer([test_input, test_input], [constants.INST_BACKWARD, constants.INST_FORWARD])[-1]
            print(f'Prediction (ITN): {outputs[0]}')
            print(f'Prediction (TN): {outputs[1]}')
            should_continue = input('\nContinue (y/n): ').strip().lower()
            if should_continue.startswith('n'):
                break
def __init__(self, input_file: str, mode: str, lang: str, keep_puncts: bool = False):
    self.lang = lang
    insts = read_data_file(input_file)

    # Build inputs and targets
    self.directions, self.inputs, self.targets = [], [], []
    for (_, w_words, s_words) in insts:
        # Extract words that are not punctuations
        processed_w_words, processed_s_words = [], []
        for w_word, s_word in zip(w_words, s_words):
            if s_word == constants.SIL_WORD:
                if keep_puncts:
                    processed_w_words.append(w_word)
                    processed_s_words.append(w_word)
                continue
            if s_word == constants.SELF_WORD:
                processed_s_words.append(w_word)
            if s_word not in constants.SPECIAL_WORDS:
                processed_s_words.append(s_word)
            processed_w_words.append(w_word)
        # Create examples
        for direction in constants.INST_DIRECTIONS:
            if direction == constants.INST_BACKWARD:
                if mode == constants.TN_MODE:
                    continue
                input_words = processed_s_words
                output_words = processed_w_words
            if direction == constants.INST_FORWARD:
                if mode == constants.ITN_MODE:
                    continue
                input_words = w_words
                output_words = processed_s_words
            # Basic tokenization
            input_words = basic_tokenize(' '.join(input_words), lang)
            output_words = basic_tokenize(' '.join(output_words), lang)
            # Update self.directions, self.inputs, self.targets
            self.directions.append(direction)
            self.inputs.append(' '.join(input_words))
            self.targets.append(' '.join(output_words))

    self.examples = list(zip(self.directions, self.inputs, self.targets))
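# Illustrative sketch (not part of the original file): how one annotated sentence
# becomes a TN-direction target. The marker strings 'sil' and '<self>' are
# assumptions standing in for constants.SIL_WORD / constants.SELF_WORD.
toy_w_words = ['Hello', ',', 'it', 'is', '5pm']
toy_s_words = ['<self>', 'sil', '<self>', '<self>', 'five p m']
# With keep_puncts=False, the 'sil' column is dropped; '<self>' copies the written word.
toy_target_tn = ' '.join(w if s == '<self>' else s for w, s in zip(toy_w_words, toy_s_words) if s != 'sil')
print(toy_target_tn)  # Hello it is five p m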
def _infer(self, sents: List[str], inst_directions: List[str]):
    """
    Main function for inference.

    Args:
        sents: A list of input texts.
        inst_directions: A list of str where each str indicates the direction of the
            corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).

    Returns:
        tag_preds: A list of lists where each list contains the tag predictions from the tagger for an input text.
        output_spans: A list of lists where each list contains the decoded semiotic spans from the decoder for an input text.
        final_outputs: A list of str where each str is the final output text for an input text.
    """
    # Preprocessing
    sents = self.input_preprocessing(list(sents))

    # Tagging
    tag_preds, nb_spans, span_starts, span_ends = self.tagger._infer(sents, inst_directions)

    # Decoding
    output_spans = self.decoder._infer(sents, nb_spans, span_starts, span_ends, inst_directions)

    # Prepare final outputs
    final_outputs = []
    for ix, (sent, tags) in enumerate(zip(sents, tag_preds)):
        cur_words, jx, span_idx = [], 0, 0
        cur_spans = output_spans[ix]
        while jx < len(sent):
            tag, word = tags[jx], sent[jx]
            if constants.SAME_TAG in tag:
                # Keep the input word unchanged
                cur_words.append(word)
                jx += 1
            elif constants.PUNCT_TAG in tag:
                # Drop punctuation
                jx += 1
            else:
                # Replace the whole transform span with the decoded output
                jx += 1
                cur_words.append(cur_spans[span_idx])
                span_idx += 1
                while jx < len(sent) and tags[jx] == constants.I_PREFIX + constants.TRANSFORM_TAG:
                    jx += 1
        cur_output_str = ' '.join(cur_words)
        cur_output_str = ' '.join(basic_tokenize(cur_output_str, self.lang))
        final_outputs.append(cur_output_str)
    return tag_preds, output_spans, final_outputs
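# Illustrative sketch (not part of the original file): how tag predictions and decoded
# spans are merged into the final output. The tag names ('SAME', 'B-TRANSFORM',
# 'I-TRANSFORM') are assumptions mirroring the SAME_TAG / TRANSFORM_TAG / I_PREFIX
# constants used above; PUNCT handling is omitted for brevity.
toy_sent = ['born', 'on', 'may', '5', '1998']
toy_tags = ['SAME', 'SAME', 'B-TRANSFORM', 'I-TRANSFORM', 'I-TRANSFORM']
toy_spans = ['may fifth nineteen ninety eight']  # one decoded span for the B-/I- run
out, jx, span_idx = [], 0, 0
while jx < len(toy_sent):
    if toy_tags[jx] == 'SAME':
        out.append(toy_sent[jx])
        jx += 1
    else:
        out.append(toy_spans[span_idx])
        span_idx += 1
        jx += 1
        while jx < len(toy_sent) and toy_tags[jx] == 'I-TRANSFORM':
            jx += 1
print(' '.join(out))  # born on may fifth nineteen ninety eight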
def __init__(self, input_file: str, mode: str, lang: str):
    self.lang = lang
    insts = read_data_file(input_file, lang=lang)

    # Build inputs and targets
    self.directions, self.inputs, self.targets, self.classes, self.nb_spans, self.span_starts, self.span_ends = (
        [],
        [],
        [],
        [],
        [],
        [],
        [],
    )
    for (classes, w_words, s_words) in insts:
        # Extract words that are not punctuations
        for direction in constants.INST_DIRECTIONS:
            # ITN mode
            if direction == constants.INST_BACKWARD:
                if mode == constants.TN_MODE:
                    continue
                (
                    processed_w_words,
                    processed_s_words,
                    processed_classes,
                    processed_nb_spans,
                    processed_s_span_starts,
                    processed_s_span_ends,
                ) = ([], [], [], 0, [], [])
                s_word_idx = 0
                for cls, w_word, s_word in zip(classes, w_words, s_words):
                    if s_word == constants.SIL_WORD:
                        continue
                    elif s_word == constants.SELF_WORD:
                        processed_s_words.append(w_word)
                    else:
                        processed_s_words.append(s_word)
                    processed_nb_spans += 1
                    processed_classes.append(cls)
                    processed_s_span_starts.append(s_word_idx)
                    s_word_idx += len(basic_tokenize(processed_s_words[-1], lang=self.lang))
                    processed_s_span_ends.append(s_word_idx)
                    processed_w_words.append(w_word)
                self.span_starts.append(processed_s_span_starts)
                self.span_ends.append(processed_s_span_ends)
                self.classes.append(processed_classes)
                self.nb_spans.append(processed_nb_spans)
                # Basic tokenization
                input_words = basic_tokenize(' '.join(processed_s_words), lang)
                # Update self.directions, self.inputs, self.targets
                self.directions.append(direction)
                self.inputs.append(' '.join(input_words))
                # self.targets is a list of lists; each inner list contains the target tokens (not words)
                self.targets.append(processed_w_words)
            # TN mode
            elif direction == constants.INST_FORWARD:
                if mode == constants.ITN_MODE:
                    continue
                (
                    processed_w_words,
                    processed_s_words,
                    processed_classes,
                    processed_nb_spans,
                    w_span_starts,
                    w_span_ends,
                ) = ([], [], [], 0, [], [])
                w_word_idx = 0
                for cls, w_word, s_word in zip(classes, w_words, s_words):
                    if s_word in constants.SPECIAL_WORDS:
                        processed_s_words.append(w_word)
                    else:
                        processed_s_words.append(s_word)
                    w_span_starts.append(w_word_idx)
                    w_word_idx += len(basic_tokenize(w_word, lang=self.lang))
                    w_span_ends.append(w_word_idx)
                    processed_nb_spans += 1
                    processed_classes.append(cls)
                    processed_w_words.append(w_word)
                self.span_starts.append(w_span_starts)
                self.span_ends.append(w_span_ends)
                self.classes.append(processed_classes)
                self.nb_spans.append(processed_nb_spans)
                # Basic tokenization
                input_words = basic_tokenize(' '.join(processed_w_words), lang)
                # Update self.directions, self.inputs, self.targets
                self.directions.append(direction)
                self.inputs.append(' '.join(input_words))
                # self.targets is a list of lists; each inner list contains the target tokens (not words)
                self.targets.append(processed_s_words)

    self.examples = list(
        zip(
            self.directions,
            self.inputs,
            self.targets,
            self.classes,
            self.nb_spans,
            self.span_starts,
            self.span_ends,
        )
    )
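# Illustrative sketch (not part of the original file): span_starts / span_ends are
# token offsets into the tokenized input, one (start, end) pair per semiotic span.
# Whitespace splitting stands in for basic_tokenize here (an assumption).
toy_s_words = ['it', 'costs', 'five dollars']  # spoken side, one entry per written word
starts, ends, idx = [], [], 0
for s in toy_s_words:
    starts.append(idx)
    idx += len(s.split())                      # number of tokens contributed by this word
    ends.append(idx)
print(list(zip(starts, ends)))                 # [(0, 1), (1, 2), (2, 4)]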
def evaluate(self, dataset: TextNormalizationTestDataset, batch_size: int, errors_log_fp: str, verbose: bool = True):
    """
    Function for evaluating the performance of the model on a dataset.

    Args:
        dataset: The dataset to be used for evaluation.
        batch_size: Batch size to use during inference. You can set it to 1 (no batching)
            if you want to measure the running time of the model per individual example
            (assuming requests are coming to the model one-by-one).
        errors_log_fp: Path to the file for logging the errors.
        verbose: If True, prints and logs various evaluation results.

    Returns:
        results: A Dict containing the evaluation results (e.g., accuracy, running time).
    """
    results = {}
    error_f = open(errors_log_fp, 'w+')

    # Apply the model on the dataset
    (
        all_run_times,
        all_dirs,
        all_inputs,
        all_targets,
        all_classes,
        all_nb_spans,
        all_span_starts,
        all_span_ends,
        all_output_spans,
    ) = ([], [], [], [], [], [], [], [], [])
    all_tag_preds, all_final_preds = [], []
    nb_iters = int(ceil(len(dataset) / batch_size))
    for i in tqdm(range(nb_iters)):
        start_idx = i * batch_size
        end_idx = (i + 1) * batch_size
        batch_insts = dataset[start_idx:end_idx]
        (
            batch_dirs,
            batch_inputs,
            batch_targets,
            batch_classes,
            batch_nb_spans,
            batch_span_starts,
            batch_span_ends,
        ) = zip(*batch_insts)
        # Inference and Running Time Measurement
        batch_start_time = perf_counter()
        batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(batch_inputs, batch_dirs)
        batch_run_time = (perf_counter() - batch_start_time) * 1000  # milliseconds
        all_run_times.append(batch_run_time)
        # Update all_dirs, all_inputs, all_tag_preds, all_final_preds and all_targets
        all_dirs.extend(batch_dirs)
        all_inputs.extend(batch_inputs)
        all_tag_preds.extend(batch_tag_preds)
        all_final_preds.extend(batch_final_preds)
        all_targets.extend(batch_targets)
        all_classes.extend(batch_classes)
        all_nb_spans.extend(batch_nb_spans)
        all_span_starts.extend(batch_span_starts)
        all_span_ends.extend(batch_span_ends)
        all_output_spans.extend(batch_output_spans)

    # Metrics
    tn_error_ctx, itn_error_ctx = 0, 0
    for direction in constants.INST_DIRECTIONS:
        (
            cur_dirs,
            cur_inputs,
            cur_tag_preds,
            cur_final_preds,
            cur_targets,
            cur_classes,
            cur_nb_spans,
            cur_span_starts,
            cur_span_ends,
            cur_output_spans,
        ) = ([], [], [], [], [], [], [], [], [], [])
        for dir, _input, tag_pred, final_pred, target, cls, nb_spans, span_starts, span_ends, output_spans in zip(
            all_dirs,
            all_inputs,
            all_tag_preds,
            all_final_preds,
            all_targets,
            all_classes,
            all_nb_spans,
            all_span_starts,
            all_span_ends,
            all_output_spans,
        ):
            if dir == direction:
                cur_dirs.append(dir)
                cur_inputs.append(_input)
                cur_tag_preds.append(tag_pred)
                cur_final_preds.append(final_pred)
                cur_targets.append(target)
                cur_classes.append(cls)
                cur_nb_spans.append(nb_spans)
                cur_span_starts.append(span_starts)
                cur_span_ends.append(span_ends)
                cur_output_spans.append(output_spans)
        nb_instances = len(cur_final_preds)
        cur_targets_sent = [" ".join(x) for x in cur_targets]
        sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
            cur_final_preds, cur_targets_sent, cur_dirs, self.lang
        )
        class_accuracy = TextNormalizationTestDataset.compute_class_accuracy(
            self.input_preprocessing(list(cur_inputs)),
            cur_targets,
            cur_tag_preds,
            cur_dirs,
            cur_output_spans,
            cur_classes,
            cur_nb_spans,
            cur_span_starts,
            cur_span_ends,
            self.lang,
        )
        # Build a printable class-accuracy summary (computed unconditionally because it
        # is also stored in the results dict below)
        if not isinstance(class_accuracy, str):
            log_class_accuracies = ""
            for key, value in class_accuracy.items():
                log_class_accuracies += f"\n\t{key}:\t{value[0]}\t{value[1]}/{value[2]}"
        else:
            log_class_accuracies = class_accuracy
        if verbose:
            logging.info(f'\n============ Direction {direction} ============')
            logging.info(f'Sentence Accuracy: {sent_accuracy}')
            logging.info(f'nb_instances: {nb_instances}')
            logging.info(f'class accuracies: {log_class_accuracies}')
        # Update results
        results[direction] = {
            'sent_accuracy': sent_accuracy,
            'nb_instances': nb_instances,
            "class_accuracy": log_class_accuracies,
        }
        # Write errors to log file
        for _input, tag_pred, final_pred, target in zip(cur_inputs, cur_tag_preds, cur_final_preds, cur_targets_sent):
            if not TextNormalizationTestDataset.is_same(final_pred, target, direction, self.lang):
                if direction == constants.INST_BACKWARD:
                    error_f.write('Backward Problem (ITN)\n')
                    itn_error_ctx += 1
                elif direction == constants.INST_FORWARD:
                    error_f.write('Forward Problem (TN)\n')
                    tn_error_ctx += 1
                formatted_input_str = get_formatted_string(basic_tokenize(_input, lang=self.lang))
                formatted_tag_pred_str = get_formatted_string(tag_pred)
                error_f.write(f'Original Input : {_input}\n')
                error_f.write(f'Input : {formatted_input_str}\n')
                error_f.write(f'Predicted Tags : {formatted_tag_pred_str}\n')
                error_f.write(f'Predicted Str : {final_pred}\n')
                error_f.write(f'Ground-Truth : {target}\n')
                error_f.write('\n')
    results['itn_error_ctx'] = itn_error_ctx
    results['tn_error_ctx'] = tn_error_ctx

    # Running Time
    avg_running_time = np.average(all_run_times) / batch_size  # in ms
    if verbose:
        logging.info(f'Average running time (normalized by batch size): {avg_running_time} ms')
    results['running_time'] = avg_running_time

    # Close log file
    error_f.close()

    return results
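# Illustrative usage sketch (not part of the original file): running the evaluation
# end-to-end. `tn_model` is assumed to be a DuplexTextNormalizationModel built as in
# main() above; the file paths and the mode/lang strings are placeholder assumptions.
test_dataset = TextNormalizationTestDataset('data/test.tsv', mode='joint', lang='en')
results = tn_model.evaluate(test_dataset, batch_size=32, errors_log_fp='errors.log')
# `results` maps each direction to {'sent_accuracy', 'nb_instances', 'class_accuracy'}
# and also carries 'running_time', 'itn_error_ctx' and 'tn_error_ctx'.
print(results)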
parser.add_argument('--lang', type=str, default=constants.ENGLISH, choices=constants.SUPPORTED_LANGS, help='Language')
args = parser.parse_args()

# Create the output dir (if it does not exist)
if not isdir(args.output_dir):
    mkdir(args.output_dir)

# Processing
train, dev, test = read_google_data(args.data_dir, args.lang)
for split, data in zip(constants.SPLIT_NAMES, [train, dev, test]):
    output_f = open(join(args.output_dir, f'{split}.tsv'), 'w+', encoding='utf-8')
    for inst in data:
        cur_classes, cur_tokens, cur_outputs = inst
        for c, t, o in zip(cur_classes, cur_tokens, cur_outputs):
            t = ' '.join(basic_tokenize(t, args.lang))
            if o not in constants.SPECIAL_WORDS:
                o_tokens = basic_tokenize(o, args.lang)
                o_tokens = [o_tok for o_tok in o_tokens if o_tok != constants.SIL_WORD]
                o = ' '.join(o_tokens)
            output_f.write(f'{c}\t{t}\t{o}\n')
        output_f.write('<eos>\t<eos>\n')
    output_f.close()
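# Illustrative sketch (not part of the original file): the layout of the written TSV,
# three tab-separated columns (semiotic class, written tokens, spoken output) with an
# <eos> line between sentences. The rows below are made-up example values.
example_rows = [
    ('PLAIN', 'born', '<self>'),
    ('DATE', 'may 5 1998', 'may fifth nineteen ninety eight'),
]
for c, t, o in example_rows:
    print(f'{c}\t{t}\t{o}')
print('<eos>\t<eos>')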