def print_class_based_stats(class2stats):
    """ Print statistics of class-based evaluation results """
    for class_name in class2stats:
        correct_count = np.sum(class2stats[class_name])
        total_count = len(class2stats[class_name])
        # Values in class2stats are per-instance 0/1 correctness flags, so the
        # average is a fraction; convert it to a percentage before formatting.
        class_acc = np.average(class2stats[class_name]) * 100
        class_acc = f'{round(class_acc, 3)}% ({correct_count}/{total_count})'
        formatted_str = get_formatted_string((class_name, class_acc),
                                             str_max_len=20)
        print(formatted_str)
    print()
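# A minimal usage sketch for print_class_based_stats. Assumptions (inferred from the
# np.sum/len/np.average calls above, not stated in the listing): class2stats maps each
# class name to a list of per-instance 0/1 correctness flags, and get_formatted_string
# is available in this module.
def _print_class_based_stats_demo():
    example_class2stats = {
        'CARDINAL': [1, 1, 0, 1],   # 3 of 4 instances normalized correctly
        'DATE': [1, 0],             # 1 of 2 instances normalized correctly
    }
    print_class_based_stats(example_class2stats)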
    def multi_validation_epoch_end(self,
                                   outputs: List,
                                   dataloader_idx=0,
                                   split="val"):
        """
        Called at the end of validation to aggregate outputs.

        Args:
            outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x[f'{split}_loss'] for x in outputs]).mean()

        # create a dictionary to store all the results
        results = {}
        # Evaluate both TN and ITN when running the joint model; otherwise only the active mode.
        directions = ([constants.TN_MODE, constants.ITN_MODE]
                      if self.mode == constants.JOINT_MODE else [self.mode])
        for class_name in self._val_class_to_id[dataloader_idx]:
            for direction in directions:
                results[f"correct_{class_name}_{direction}"] = 0
                results[f"total_{class_name}_{direction}"] = 0

        # Sum the per-step counts reported by each validation step; a key missing
        # from a step simply contributes nothing.
        for key in results:
            count = [x[key] for x in outputs if key in x]
            results[key] = (torch.stack(count).sum() if len(count) > 0
                            else torch.tensor(0).to(self.device))

        all_results = defaultdict(list)

        # Gather every rank's counts; in non-distributed runs the local counts are used as-is.
        if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            for _ in range(world_size):
                for key, v in results.items():
                    all_results[key].append(torch.empty_like(v))
            for key, v in results.items():
                torch.distributed.all_gather(all_results[key], v)
        else:
            for key, v in results.items():
                all_results[key].append(v)

        # Only rank 0 (or the single process when not distributed) builds the report.
        if (not torch.distributed.is_initialized()
                or torch.distributed.get_rank() == 0):
            if split == "test":
                val_name = self._test_names[dataloader_idx].upper()
            else:
                val_name = self._validation_names[dataloader_idx].upper()
            final_results = defaultdict(int)
            for key, v in all_results.items():
                for _v in v:
                    final_results[key] += _v.item()

            accuracies = defaultdict(dict)
            for key, value in final_results.items():
                if "total_" in key:
                    # Keys follow the pattern "total_<class_name>_<mode>" built above.
                    _, class_name, mode = key.split('_')
                    correct = final_results[f"correct_{class_name}_{mode}"]
                    if value == 0:
                        accuracies[mode][class_name] = (0, correct, value)
                    else:
                        acc = round(correct / value * 100, 3)
                        accuracies[mode][class_name] = (acc, correct, value)

            for mode, values in accuracies.items():
                report = f"Accuracy {mode.upper()} task {val_name}:\n"
                report += '\n'.join([
                    get_formatted_string(
                        (class_name, f'{v[0]}% ({v[1]}/{v[2]})'),
                        str_max_len=24) for class_name, v in values.items()
                ])
                # calculate average across all classes
                all_total = 0
                all_correct = 0
                for _, class_values in values.items():
                    _, correct, total = class_values
                    all_correct += correct
                    all_total += total
                all_acc = (round((all_correct / all_total) * 100, 3)
                           if all_total > 0 else 0)
                report += '\n' + get_formatted_string(
                    ('AVG', f'{all_acc}% ({all_correct}/{all_total})'),
                    str_max_len=24)
                logging.info(report)
                # Stored as a one-element list so the logging loop below can read values[0].
                accuracies[mode]['AVG'] = [all_acc]

        self.log(f'{split}_loss', avg_loss)
        # accuracies and val_name exist only on global rank 0, which matches this guard.
        if self.trainer.is_global_zero:
            for mode in accuracies:
                for class_name, values in accuracies[mode].items():
                    self.log(
                        f'{val_name}_{mode.upper()}_acc_{class_name.upper()}',
                        values[0],
                        rank_zero_only=True)
        return {
            f'{split}_loss': avg_loss,
        }
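# Standalone sketch of the cross-rank aggregation pattern used above: every rank
# all_gathers its per-key counts, and rank 0 can then sum them. Falls back to the
# local counts when torch.distributed is not initialized, so it also runs in a
# single process. (Illustrative helper, not part of the model class.)
import torch


def aggregate_counts_across_ranks(local_counts: dict) -> dict:
    gathered = {}
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        for key, value in local_counts.items():
            # One receive buffer per rank, filled in-place by all_gather.
            buffers = [torch.empty_like(value) for _ in range(world_size)]
            torch.distributed.all_gather(buffers, value)
            gathered[key] = sum(int(t.item()) for t in buffers)
    else:
        gathered = {key: int(value.item()) for key, value in local_counts.items()}
    return gathered


# Example: aggregate_counts_across_ranks({'correct_DATE_tn': torch.tensor(5)}) -> {'correct_DATE_tn': 5}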
Example #3
    def evaluate(self,
                 dataset: TextNormalizationTestDataset,
                 batch_size: int,
                 errors_log_fp: str,
                 verbose: bool = True):
        """ Function for evaluating the performance of the model on a dataset

        Args:
            dataset: The dataset to be used for evaluation.
            batch_size: Batch size to use during inference. You can set it to be 1
                (no batching) if you want to measure the running time of the model
                per individual example (assuming requests are coming to the model one-by-one).
            errors_log_fp: Path to the file for logging the errors
            verbose: if true prints and logs various evaluation results

        Returns:
            results: A Dict containing the evaluation results (e.g., accuracy, running time)
        """
        results = {}
        error_f = open(errors_log_fp, 'w+')

        # Apply the model on the dataset
        all_run_times, all_dirs, all_inputs = [], [], []
        all_tag_preds, all_final_preds, all_targets = [], [], []
        nb_iters = int(ceil(len(dataset) / batch_size))
        for i in tqdm(range(nb_iters)):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size
            batch_insts = dataset[start_idx:end_idx]
            batch_dirs, batch_inputs, batch_targets = zip(*batch_insts)
            # Inference and Running Time Measurement
            batch_start_time = perf_counter()
            batch_tag_preds, _, batch_final_preds = self._infer(
                batch_inputs, batch_dirs)
            batch_run_time = (perf_counter() -
                              batch_start_time) * 1000  # milliseconds
            all_run_times.append(batch_run_time)
            # Update all_dirs, all_inputs, all_tag_preds, all_final_preds and all_targets
            all_dirs.extend(batch_dirs)
            all_inputs.extend(batch_inputs)
            all_tag_preds.extend(batch_tag_preds)
            all_final_preds.extend(batch_final_preds)
            all_targets.extend(batch_targets)

        # Metrics
        tn_error_ctx, itn_error_ctx = 0, 0
        for direction in constants.INST_DIRECTIONS:
            cur_dirs, cur_inputs, cur_tag_preds, cur_final_preds, cur_targets = [], [], [], [], []
            for dir, _input, tag_pred, final_pred, target in zip(
                    all_dirs, all_inputs, all_tag_preds, all_final_preds,
                    all_targets):
                if dir == direction:
                    cur_dirs.append(dir)
                    cur_inputs.append(_input)
                    cur_tag_preds.append(tag_pred)
                    cur_final_preds.append(final_pred)
                    cur_targets.append(target)
            nb_instances = len(cur_final_preds)
            sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
                cur_final_preds, cur_targets, cur_dirs)
            if verbose:
                logging.info(
                    f'\n============ Direction {direction} ============')
                logging.info(f'Sentence Accuracy: {sent_accuracy}')
                logging.info(f'nb_instances: {nb_instances}')
            # Update results
            results[direction] = {
                'sent_accuracy': sent_accuracy,
                'nb_instances': nb_instances
            }
            # Write errors to log file
            for _input, tag_pred, final_pred, target in zip(
                    cur_inputs, cur_tag_preds, cur_final_preds, cur_targets):
                if not TextNormalizationTestDataset.is_same(
                        final_pred, target, direction):
                    if direction == constants.INST_BACKWARD:
                        error_f.write('Backward Problem (ITN)\n')
                        itn_error_ctx += 1
                    elif direction == constants.INST_FORWARD:
                        error_f.write('Forward Problem (TN)\n')
                        tn_error_ctx += 1
                    formatted_input_str = get_formatted_string(
                        _input.split(' '))
                    formatted_tag_pred_str = get_formatted_string(tag_pred)
                    error_f.write(f'Original Input : {_input}\n')
                    error_f.write(f'Input          : {formatted_input_str}\n')
                    error_f.write(
                        f'Predicted Tags : {formatted_tag_pred_str}\n')
                    error_f.write(f'Predicted      : {final_pred}\n')
                    error_f.write(f'Ground-Truth   : {target}\n')
                    error_f.write('\n')
            results['itn_error_ctx'] = itn_error_ctx
            results['tn_error_ctx'] = tn_error_ctx

        # Running Time: mean per-batch latency divided by batch size, i.e. milliseconds per example
        avg_running_time = np.average(all_run_times) / batch_size  # in ms
        if verbose:
            logging.info(
                f'Average running time (normalized by batch size): {avg_running_time} ms'
            )
        results['running_time'] = avg_running_time

        # Close log file
        error_f.close()

        return results
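# Hypothetical call site for evaluate(); the dataset construction and file path below
# are assumptions for illustration (constructor arguments are not shown in this listing).
# As the docstring notes, batch_size=1 measures per-example latency.
#
#     test_dataset = TextNormalizationTestDataset(...)      # args omitted here
#     results = model.evaluate(dataset=test_dataset,
#                              batch_size=1,
#                              errors_log_fp='eval_errors.log',
#                              verbose=True)
#     print(f"{results['running_time']} ms per example")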
Example #4
    def evaluate(self,
                 dataset: TextNormalizationTestDataset,
                 batch_size: int,
                 errors_log_fp: str,
                 verbose: bool = True):
        """ Function for evaluating the performance of the model on a dataset

        Args:
            dataset: The dataset to be used for evaluation.
            batch_size: Batch size to use during inference. You can set it to be 1
                (no batching) if you want to measure the running time of the model
                per individual example (assuming requests are coming to the model one-by-one).
            errors_log_fp: Path to the file for logging the errors
            verbose: if true prints and logs various evaluation results

        Returns:
            results: A Dict containing the evaluation results (e.g., accuracy, running time)
        """
        results = {}
        error_f = open(errors_log_fp, 'w+')

        # Apply the model on the dataset
        (
            all_run_times,
            all_dirs,
            all_inputs,
            all_targets,
            all_classes,
            all_nb_spans,
            all_span_starts,
            all_span_ends,
            all_output_spans,
        ) = ([], [], [], [], [], [], [], [], [])
        all_tag_preds, all_final_preds = [], []
        nb_iters = int(ceil(len(dataset) / batch_size))
        for i in tqdm(range(nb_iters)):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size
            batch_insts = dataset[start_idx:end_idx]
            (
                batch_dirs,
                batch_inputs,
                batch_targets,
                batch_classes,
                batch_nb_spans,
                batch_span_starts,
                batch_span_ends,
            ) = zip(*batch_insts)
            # Inference and Running Time Measurement
            batch_start_time = perf_counter()
            batch_tag_preds, batch_output_spans, batch_final_preds = self._infer(
                batch_inputs, batch_dirs)
            batch_run_time = (perf_counter() -
                              batch_start_time) * 1000  # milliseconds
            all_run_times.append(batch_run_time)
            # Accumulate this batch's directions, inputs, predictions, targets, classes, and span info
            all_dirs.extend(batch_dirs)
            all_inputs.extend(batch_inputs)
            all_tag_preds.extend(batch_tag_preds)
            all_final_preds.extend(batch_final_preds)
            all_targets.extend(batch_targets)
            all_classes.extend(batch_classes)
            all_nb_spans.extend(batch_nb_spans)
            all_span_starts.extend(batch_span_starts)
            all_span_ends.extend(batch_span_ends)
            all_output_spans.extend(batch_output_spans)

        # Metrics
        tn_error_ctx, itn_error_ctx = 0, 0
        for direction in constants.INST_DIRECTIONS:
            (
                cur_dirs,
                cur_inputs,
                cur_tag_preds,
                cur_final_preds,
                cur_targets,
                cur_classes,
                cur_nb_spans,
                cur_span_starts,
                cur_span_ends,
                cur_output_spans,
            ) = ([], [], [], [], [], [], [], [], [], [])
            for dir, _input, tag_pred, final_pred, target, cls, nb_spans, span_starts, span_ends, output_spans in zip(
                    all_dirs,
                    all_inputs,
                    all_tag_preds,
                    all_final_preds,
                    all_targets,
                    all_classes,
                    all_nb_spans,
                    all_span_starts,
                    all_span_ends,
                    all_output_spans,
            ):
                if dir == direction:
                    cur_dirs.append(dir)
                    cur_inputs.append(_input)
                    cur_tag_preds.append(tag_pred)
                    cur_final_preds.append(final_pred)
                    cur_targets.append(target)
                    cur_classes.append(cls)
                    cur_nb_spans.append(nb_spans)
                    cur_span_starts.append(span_starts)
                    cur_span_ends.append(span_ends)
                    cur_output_spans.append(output_spans)
            nb_instances = len(cur_final_preds)
            cur_targets_sent = [" ".join(x) for x in cur_targets]
            sent_accuracy = TextNormalizationTestDataset.compute_sent_accuracy(
                cur_final_preds, cur_targets_sent, cur_dirs, self.lang)
            class_accuracy = TextNormalizationTestDataset.compute_class_accuracy(
                [basic_tokenize(x, lang=self.lang) for x in cur_inputs],
                cur_targets,
                cur_tag_preds,
                cur_dirs,
                cur_output_spans,
                cur_classes,
                cur_nb_spans,
                cur_span_starts,
                cur_span_ends,
                self.lang,
            )
            # Build a readable per-class accuracy summary; it is needed in the results
            # dict below even when verbose logging is disabled.
            if not isinstance(class_accuracy, str):
                log_class_accuracies = ""
                for key, value in class_accuracy.items():
                    log_class_accuracies += f"\n\t{key}:\t{value[0]}\t{value[1]}/{value[2]}"
            else:
                log_class_accuracies = class_accuracy
            if verbose:
                logging.info(
                    f'\n============ Direction {direction} ============')
                logging.info(f'Sentence Accuracy: {sent_accuracy}')
                logging.info(f'nb_instances: {nb_instances}')
                logging.info(f'class accuracies: {log_class_accuracies}')
            # Update results
            results[direction] = {
                'sent_accuracy': sent_accuracy,
                'nb_instances': nb_instances,
                "class_accuracy": log_class_accuracies,
            }
            # Write errors to log file
            for _input, tag_pred, final_pred, target, classes in zip(
                    cur_inputs, cur_tag_preds, cur_final_preds,
                    cur_targets_sent, cur_classes):
                if not TextNormalizationTestDataset.is_same(
                        final_pred, target, direction, self.lang):
                    if direction == constants.INST_BACKWARD:
                        error_f.write('Backward Problem (ITN)\n')
                        itn_error_ctx += 1
                    elif direction == constants.INST_FORWARD:
                        error_f.write('Forward Problem (TN)\n')
                        tn_error_ctx += 1
                    formatted_input_str = get_formatted_string(
                        basic_tokenize(_input, lang=self.lang))
                    formatted_tag_pred_str = get_formatted_string(tag_pred)
                    class_str = " ".join(classes)
                    error_f.write(f'Original Input : {_input}\n')
                    error_f.write(f'Input          : {formatted_input_str}\n')
                    error_f.write(
                        f'Predicted Tags : {formatted_tag_pred_str}\n')
                    error_f.write(f'Ground Classes : {class_str}\n')
                    error_f.write(f'Predicted Str  : {final_pred}\n')
                    error_f.write(f'Ground-Truth   : {target}\n')
                    error_f.write('\n')
            results['itn_error_ctx'] = itn_error_ctx
            results['tn_error_ctx'] = tn_error_ctx

        # Running Time: mean per-batch latency divided by batch size, i.e. milliseconds per example
        avg_running_time = np.average(all_run_times) / batch_size  # in ms
        if verbose:
            logging.info(
                f'Average running time (normalized by batch size): {avg_running_time} ms'
            )
        results['running_time'] = avg_running_time

        # Close log file
        error_f.close()
        logging.info(f'Errors are saved at {errors_log_fp}.')
        return results
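# Small numeric check of the running-time normalization used above (an illustration,
# not code from the listing): all_run_times holds per-batch latencies in ms, so the
# mean divided by batch_size approximates milliseconds per example.
import numpy as np

batch_size = 4
all_run_times = [80.0, 120.0, 100.0]      # three batches of 4 examples each, in ms
ms_per_example = np.average(all_run_times) / batch_size
assert abs(ms_per_example - 25.0) < 1e-9  # (80 + 120 + 100) / 3 / 4 == 25 ms per example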