Example #1
def write_metrics(metrics, config):
    from lidbox.models.keras_utils import experiment_cache_from_config
    metrics_file = os.path.join(experiment_cache_from_config(config),
                                "predictions", "metrics.json")
    os.makedirs(os.path.dirname(metrics_file), exist_ok=True)
    logger.info("Writing evaluated metrics to '%s'", metrics_file)
    with open(metrics_file, "w") as f:
        json.dump(metrics, f)
Example #2
File: api.py  Project: gaoyiyeah/lidbox
def _metrics_file_from_config(config, dataset_name):
    from lidbox.models.keras_utils import experiment_cache_from_config
    return os.path.join(experiment_cache_from_config(config), "predictions",
                        dataset_name, "metrics.json")
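The two helpers above share one layout under the experiment cache: evaluated metrics go into a "predictions" directory, optionally namespaced by dataset. As a rough sketch of how they fit together, the hypothetical write_metrics_for_dataset below reuses _metrics_file_from_config from Example #2 instead of hard-coding the path as Example #1 does; the function name and logger setup here are assumptions, not lidbox code.

import json
import logging
import os

logger = logging.getLogger(__name__)

def write_metrics_for_dataset(metrics, config, dataset_name):
    # Hypothetical combination of the two examples above: derive the
    # per-dataset metrics path, create its directory, and dump the metrics.
    metrics_file = _metrics_file_from_config(config, dataset_name)
    os.makedirs(os.path.dirname(metrics_file), exist_ok=True)
    logger.info("Writing evaluated metrics to '%s'", metrics_file)
    with open(metrics_file, "w") as f:
        json.dump(metrics, f)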
Example #3
def create_dataset(split, labels, init_data, config):
    """
    split:
        Split key.
    labels:
        All labels from all datasets.
    init_data:
        All metadata by split from all datasets.
    config:
        Contents of the lidbox config file, unmodified.
    """
    # Configure steps to create dataset iterator
    steps = [
        # Create a tf.data.Dataset that contains all metadata, e.g. paths from utt2path and labels from utt2label etc.
        Step("initialize", {
            "labels": labels,
            "init_data": init_data
        }),
    ]
    if "post_initialize" in config:
        # "Pre-pre-process" all metadata before any signals are read
        if "file_limit" in config["post_initialize"]:
            steps.append(Step("lambda", {
                "fn": lambda ds: ds.take(config["post_initialize"]["file_limit"])
            }))
        if "shuffle_buffer_size" in config["post_initialize"]:
            # Shuffle all files
            steps.append(Step("shuffle", {
                "buffer_size": config["post_initialize"]["shuffle_buffer_size"]
            }))
        if "binary_classification" in config["post_initialize"]:
            # Convert all labels to binary classification
            steps.append(Step("convert_to_binary_classification", {
                "positive_class": config["post_initialize"]["binary_classification"]
            }))
        if config["post_initialize"].get("check_wav_headers", False):
            steps.append(Step("drop_invalid_wavs", {}))
    if "features" in config and config["features"]["type"] == "kaldi":
        # Features will be imported from Kaldi files, assume no signals should be loaded
        pass
    else:
        # Assume all features will be extracted from signals
        steps.extend([
            # Load signals from all paths
            Step("load_audio", {
                # Use chained .get calls so a missing 'num_prefetched_signals'
                # key falls back to None instead of raising KeyError when
                # 'post_initialize' is present without it
                "num_prefetch": config.get("post_initialize", {}).get("num_prefetched_signals")
            }),
            # Drop empty signals
            Step("drop_empty", {})
        ])
    if "pre_process" in config:
        # Pre-processing before feature extraction has been defined in the config file
        if "filters" in config["pre_process"]:
            # Drop unwanted signals
            steps.append(
                Step("apply_filters",
                     {"config": config["pre_process"]["filters"]}))
        if "webrtcvad" in config["pre_process"] or "rms_vad" in config[
                "pre_process"]:
            # Voice activity detection
            if "webrtcvad" in config["pre_process"]:
                # Compute WebRTC VAD decisions
                steps.append(
                    Step("compute_webrtc_vad",
                         config["pre_process"]["webrtcvad"]))
            elif "rms_vad" in config["pre_process"]:
                # Compute VAD decisions by comparing the RMS value of each VAD frame to the mean RMS value over each signal
                steps.append(
                    Step("compute_rms_vad", config["pre_process"]["rms_vad"]))
            steps.extend([
                # Drop non-speech frames using computed decisions
                Step("apply_vad", {}),
                # Some signals might contain only non-speech frames
                Step("drop_empty", {}),
            ])
        if "repeat_too_short_signals" in config["pre_process"]:
            # Repeat all signals until they are of given length
            steps.append(
                Step("repeat_too_short_signals",
                     config["pre_process"]["repeat_too_short_signals"]))
        if "augment" in config["pre_process"]:
            augment_configs = [
                conf for conf in config["pre_process"]["augment"]
                if conf["split"] == split
            ]
            # Apply augmentation only if this dataset split was specified to be augmented
            if augment_configs:
                steps.append(
                    Step("augment_signals",
                         {"augment_configs": augment_configs}))
        if "chunks" in config["pre_process"]:
            # Divide signals into fixed-length chunks
            steps.append(
                Step("create_signal_chunks", config["pre_process"]["chunks"]))
        # TODO not yet supported
        # if "random_chunks" in config["pre_process"]:
        if "cache" in config["pre_process"]:
            steps.extend(
                _get_cache_steps(config["pre_process"]["cache"], split))
    if "features" in config:
        # Load features
        if config["features"]["type"] == "kaldi":
            # Pre-extracted Kaldi features will be used as input
            steps.append(
                # Use the 'kaldi_ark_key' to load contents from an external Kaldi archive file and drop Kaldi metadata
                Step("load_kaldi_data",
                     {"shape": config["features"]["kaldi"]["shape"]}))
        else:
            # Features will be extracted from 'signal' and stored under 'input'
            # Uses GPU by default, can be changed with the 'device' key
            steps.append(
                Step("extract_features", {"config": config["features"]}))
    if "post_process" in config:
        if "filters" in config["post_process"]:
            # Drop unwanted features
            steps.append(
                Step("apply_filters",
                     {"config": config["post_process"]["filters"]}))
        if "chunks" in config["post_process"]:
            # Divide inputs into fixed-length chunks
            steps.append(
                Step("create_input_chunks", config["post_process"]["chunks"]))
        if "normalize" in config["post_process"]:
            steps.append(
                Step("normalize",
                     {"config": config["post_process"]["normalize"]}))
        if "shuffle_buffer_size" in config["post_process"]:
            steps.append(Step("shuffle", {
                "buffer_size": config["post_process"]["shuffle_buffer_size"]
            }))
        if "tensorboard" in config["post_process"]:
            tensorboard_config = {
                "summary_dir": os.path.join(experiment_cache_from_config(config),
                                            "tensorboard", "dataset", split),
                "config": config["post_process"]["tensorboard"]
            }
            # Add some samples to TensorBoard for inspection
            steps.append(Step("consume_to_tensorboard", tensorboard_config))
        if "remap_keys" in config["post_process"]:
            steps.append(
                Step("remap_keys",
                     {"new_keys": config["post_process"]["remap_keys"]}))
        if "cache" in config["post_process"]:
            steps.extend(
                _get_cache_steps(config["post_process"]["cache"], split))
    # TODO convert to binary classification here
    # TODO pre_training config key
    if "experiment" in config:
        # Check if this split should be shuffled before training
        for experiment_conf in config["experiment"]["data"].values():
            if experiment_conf["split"] == split and "shuffle_buffer_size" in experiment_conf:
                steps.append(Step("shuffle", {
                    "buffer_size": experiment_conf["shuffle_buffer_size"]
                }))
                break
    if "embeddings" in config:
        steps.append(
            Step("extract_embeddings", {"config": config["embeddings"]}))
        if "remap_keys" in config["embeddings"]:
            steps.append(
                Step("remap_keys",
                     {"new_keys": config["embeddings"]["remap_keys"]}))
        if "cache" in config["embeddings"]:
            steps.extend(_get_cache_steps(config["embeddings"]["cache"],
                                          split))
    return steps
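create_dataset only builds a list of Step descriptors; something downstream still has to resolve each step key to a function and fold the steps over a tf.data.Dataset. The sketch below illustrates that pattern in isolation: the Step namedtuple, the STEP_FUNCTIONS registry, and apply_steps are simplified stand-ins assumed for illustration, not lidbox's actual implementation.

import collections

# Simplified stand-in; lidbox's own Step type and step registry may differ.
Step = collections.namedtuple("Step", ["key", "kwargs"])

def shuffle(ds, buffer_size):
    return ds.shuffle(buffer_size)

STEP_FUNCTIONS = {
    "shuffle": shuffle,
    # "initialize", "load_audio", "extract_features", ... would be registered here
}

def apply_steps(steps):
    # The first step ("initialize") receives ds=None and creates the dataset,
    # mirroring the initialize(None, ...) call in Example #4; every later step
    # transforms the tf.data.Dataset returned by the previous one.
    ds = None
    for step in steps:
        ds = STEP_FUNCTIONS[step.key](ds, **step.kwargs)
    return ds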
Example #4
def evaluate_test_set(split2ds, split2meta, labels, config):
    from lidbox.dataset.steps import as_supervised, initialize
    from lidbox.models.keras_utils import best_model_checkpoint_from_config, experiment_cache_from_config
    test_conf = config["experiment"]["data"]["test"]
    test_ds = (split2ds[test_conf["split"]]
               .batch(test_conf["batch_size"])
               .apply(as_supervised))
    predictions = None
    if "user_script" in config:
        user_script = load_user_script_as_module(config["user_script"])
        if hasattr(user_script, "predict"):
            logger.info(
                "User script has defined a 'predict' function, will use it")
            predictions = user_script.predict(test_ds, config)
            if predictions is None:
                logger.error(
                    "Function 'predict' in the user script '%s' did not return predictions",
                    config["user_script"])
                return
    if predictions is None:
        logger.info(
            "User script has not defined a 'predict' function, will use default approach"
        )
        keras_wrapper = KerasWrapper.from_config(config)
        logger.info("Model initialized:\n%s", str(keras_wrapper))
        best_checkpoint = best_model_checkpoint_from_config(config)
        logger.info("Loading weights from checkpoint file '%s'",
                    best_checkpoint)
        keras_wrapper.load_weights(best_checkpoint)
        logger.info("Starting prediction with model '%s'",
                    keras_wrapper.model_key)
        predictions = keras_wrapper.keras_model.predict(test_ds)
    logger.info(
        "Model returned predictions of shape %s, now gathering all test set ids",
        repr(predictions.shape))
    test_ids = [
        x["id"].decode("utf-8")
        for x in split2ds[test_conf["split"]].as_numpy_iterator()
    ]
    utt2prediction = sorted(zip(test_ids, predictions), key=lambda t: t[0])
    del test_ids
    has_chunks = False
    if "chunks" in config.get("pre_process", {}):
        logger.info(
            "Original signals were divided into chunks, merging chunk scores by averaging"
        )
        has_chunks = True
    if "chunks" in config.get("post_process", {}):
        logger.info(
            "Extracted features were divided into chunks, merging chunk scores by averaging"
        )
        has_chunks = True
    if has_chunks:
        utt2prediction = group_chunk_predictions_by_parent_id(utt2prediction)
        predictions = np.array([p for _, p in utt2prediction])
    # Collect targets from the test set iterator
    test_meta_ds = initialize(None, labels, split2meta[test_conf["split"]])
    utt2target = {
        x["id"].decode("utf-8"): x["target"]
        for x in test_meta_ds.as_numpy_iterator()
    }
    missed_utterances = set(utt2target.keys()) - set(u for u, _ in utt2prediction)
    min_score = np.amin(predictions)
    max_score = np.amax(predictions)
    if missed_utterances:
        logger.info(
            "%d test samples had no predictions; the worst-case score %.3f will be used for them for every label",
            len(missed_utterances), min_score)
        utt2prediction.extend([(utt, np.array([min_score for _ in labels]))
                               for utt in sorted(missed_utterances)])
    scores_file = os.path.join(experiment_cache_from_config(config),
                               "predictions", "scores")
    os.makedirs(os.path.dirname(scores_file), exist_ok=True)
    logger.info("Writing predicted scores to '%s'", scores_file)
    if os.path.exists(scores_file):
        logger.warning("Overwriting existing '%s'", scores_file)
    with open(scores_file, "w") as scores_f:
        print_predictions(utt2prediction, labels, file=scores_f)
    metric_results = []
    # Ensure true labels are always in the same order as in predictions
    predictions = np.array([p for _, p in utt2prediction])
    true_labels_sparse = np.array([utt2target[u] for u, _ in utt2prediction])
    pred_labels_sparse = np.argmax(predictions, axis=1)
    # The dense detection cost and EER branches below also need one-hot targets
    # and the set of labels that actually occur in the test set; neither is
    # defined in this excerpt, so the following definitions are assumed
    # reconstructions.
    true_labels = np.eye(len(labels))[true_labels_sparse]
    all_testset_labels = set(labels[t] for t in true_labels_sparse)
    logger.info(
        "Evaluating metrics on true labels of shape %s and predicted labels of shape %s",
        true_labels_sparse.shape, pred_labels_sparse.shape)
    for metric in test_conf["evaluate_metrics"]:
        result = None
        if metric["name"].endswith("average_detection_cost"):
            logger.info("Evaluating minimum average detection cost")
            thresholds = np.linspace(min_score, max_score,
                                     metric.get("num_thresholds", 200))
            if metric["name"].startswith("sparse_"):
                cavg = lidbox.metrics.SparseAverageDetectionCost(
                    len(labels), thresholds)
                cavg.update_state(np.expand_dims(true_labels_sparse, -1),
                                  predictions)
            else:
                cavg = lidbox.metrics.AverageDetectionCost(
                    len(labels), thresholds)
                cavg.update_state(true_labels, predictions)
            result = float(cavg.result().numpy())
            logger.info("%s: %.6f", metric["name"], result)
        elif metric["name"].endswith("average_equal_error_rate"):
            #TODO sparse EER, generate one-hot true_labels
            logger.info("Evaluating average equal error rate")
            eer = np.zeros(len(labels))
            for l, label in enumerate(labels):
                if label not in all_testset_labels:
                    eer[l] = 0
                    continue
                # https://stackoverflow.com/a/46026962
                fpr, tpr, _ = sklearn.metrics.roc_curve(
                    true_labels[:, l], predictions[:, l])
                fnr = 1 - tpr
                eer[l] = fpr[np.nanargmin(np.absolute(fnr - fpr))]
            result = {
                "avg": float(eer.mean()),
                "by_label": {label: float(eer[l]) for l, label in enumerate(labels)}
            }
            logger.info("%s: %s", metric["name"],
                        lidbox.yaml_pprint(result, to_string=True))
        elif metric["name"] == "average_f1_score":
            logger.info("Evaluating average F1 score")
            f1 = sklearn.metrics.f1_score(true_labels_sparse,
                                          pred_labels_sparse,
                                          labels=list(range(len(labels))),
                                          average="weighted")
            result = {"avg": float(f1)}
            logger.info("%s: %.6f", metric["name"], f1)
        elif metric["name"] == "sklearn_classification_report":
            logger.info("Generating full sklearn classification report")
            result = sklearn.metrics.classification_report(
                true_labels_sparse,
                pred_labels_sparse,
                labels=list(range(len(labels))),
                target_names=labels,
                output_dict=True,
                zero_division=0)
            logger.info("%s:\n%s", metric["name"],
                        lidbox.yaml_pprint(result, left_pad=2, to_string=True))
        elif metric["name"] == "confusion_matrix":
            logger.info("Generating confusion matrix")
            result = sklearn.metrics.confusion_matrix(true_labels_sparse,
                                                      pred_labels_sparse)
            logger.info("%s:\n%s", metric["name"],
                        format_confusion_matrix(result, labels))
            result = result.tolist()
        else:
            logger.error("Cannot evaluate unknown metric '%s'", metric["name"])
        metric_results.append({"name": metric["name"], "result": result})
    return metric_results
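The per-label equal error rate in the branch above follows the ROC trick from the linked Stack Overflow answer: pick the operating point where the false positive rate and the false negative rate (1 - TPR) are closest. A self-contained version of that computation, with made-up labels and scores, might look like this.

import numpy as np
import sklearn.metrics

def equal_error_rate(y_true, y_score):
    # EER is where the false positive rate equals the false negative rate;
    # approximate it with the FPR at the threshold minimizing |FNR - FPR|.
    fpr, tpr, _ = sklearn.metrics.roc_curve(y_true, y_score)
    fnr = 1 - tpr
    return float(fpr[np.nanargmin(np.absolute(fnr - fpr))])

# Toy binary example with made-up labels and scores
y_true = np.array([1, 1, 0, 0, 1, 0])
y_score = np.array([0.9, 0.7, 0.6, 0.2, 0.4, 0.3])
print(equal_error_rate(y_true, y_score))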