Example #1
    def search_best_median(self, spans, best_th=None, target="Event"):
        best_span = {k: 1 for k in self.labels}
        best_f1 = {k: 0.0 for k in self.labels}

        for span in spans:
            logging.info(f"median filter span: {span}")
            post_process_fn = [
                functools.partial(median_filt_1d, filt_span=span)
            ]
            if best_th is not None:
                prediction_df = self.get_prediction_dataframe(
                    post_processing=post_process_fn,
                    threshold=list(best_th.values()),
                    binarization_type="class_threshold",
                )
            else:
                prediction_df = self.get_prediction_dataframe(
                    post_processing=post_process_fn)
            events_metric, segments_metric, psds_m_f1 = compute_metrics(
                prediction_df, self.validation_df, self.durations_validation)
            for i, label in enumerate(self.labels):
                f1 = events_metric.class_wise_f_measure(
                    event_label=label)["f_measure"]
                # if target == 'Event':
                #     f1 = valid_events_metric.class_wise_f_measure(event_label=label)['f_measure']
                # elif target == 'Frame':
                #     f1 = frame_measure[i].calc_f1()[2]
                # else:
                #     raise NotImplementedError
                if f1 > best_f1[label]:
                    best_span[label] = span
                    best_f1[label] = f1

        post_process_fn = [
            functools.partial(median_filt_1d,
                              filt_span=list(best_span.values()))
        ]
        if best_th is not None:
            prediction_df = self.get_prediction_dataframe(
                post_processing=post_process_fn,
                threshold=list(best_th.values()),
                binarization_type="class_threshold",
            )
        else:
            prediction_df = self.get_prediction_dataframe(
                post_processing=post_process_fn)
        # Compute evaluation metrics
        events_metric, segments_metric, psds_m_f1 = compute_metrics(
            prediction_df, self.validation_df, self.durations_validation)
        macro_f1_event = events_metric.results_class_wise_average_metrics(
        )["f_measure"]["f_measure"]
        macro_f1_segment = segments_metric.results_class_wise_average_metrics(
        )["f_measure"]["f_measure"]

        logging.info(f"best_span: {best_span}")
        logging.info(f"best_f1: {best_f1}")
        return best_span, best_f1
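
For reference, median_filt_1d is used above but not shown in these examples. A minimal sketch of a per-class median filter that accepts the filt_span argument as either a single span or one span per class; the name and exact semantics are assumptions, not the project's actual implementation.

import numpy as np
from scipy.ndimage import median_filter

def median_filt_1d(predictions, filt_span=7):
    # Median-filter a (time, classes) prediction matrix along the time axis.
    # filt_span may be a scalar or a per-class list, mirroring how the
    # examples pass either `span` or `list(best_span.values())`.
    predictions = np.asarray(predictions)
    if np.isscalar(filt_span):
        filt_span = [filt_span] * predictions.shape[1]
    filtered = np.empty_like(predictions)
    for class_idx, span in enumerate(filt_span):
        filtered[:, class_idx] = median_filter(
            predictions[:, class_idx], size=int(span), mode="nearest")
    return filtered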
Example #2
    def search_best_threshold(self, step, target="Event"):
        assert 0 < step < 1.0
        assert target in ["Event", "Frame"]
        best_th = {k: 0.0 for k in self.labels}
        best_f1 = {k: 0.0 for k in self.labels}

        for th in np.arange(step, 1.0, step):
            logging.info(f"threshold: {th}")
            prediction_df = self.get_prediction_dataframe(
                threshold=th,
                binarization_type="global_threshold",
                save_predictions=None,
            )
            events_metric, segments_metric, psds_m_f1 = compute_metrics(
                prediction_df, self.validation_df, self.durations_validation)

            for i, label in enumerate(self.labels):
                f1 = events_metric.class_wise_f_measure(
                    event_label=label)["f_measure"]
                # if target == 'Event':
                #     f1 = valid_events_metric.class_wise_f_measure(event_label=label)['f_measure']
                # elif target == 'Frame':
                #     f1 = frame_measure[i].calc_f1()[2]
                # else:
                #     raise NotImplementedError
                if f1 > best_f1[label]:
                    best_th[label] = th
                    best_f1[label] = f1

        thres_list = [best_th[label] for label in self.labels]

        prediction_df = self.get_prediction_dataframe(
            post_processing=None,
            threshold=thres_list,
            binarization_type="class_threshold",
        )

        # Compute evaluation metrics
        events_metric, segments_metric, psds_m_f1 = compute_metrics(
            prediction_df, self.validation_df, self.durations_validation)
        macro_f1_event = events_metric.results_class_wise_average_metrics(
        )["f_measure"]["f_measure"]
        macro_f1_segment = segments_metric.results_class_wise_average_metrics(
        )["f_measure"]["f_measure"]

        logging.info(
            f"Event-based F1:{macro_f1_event * 100:.4}\tSegment-based F1:{macro_f1_event * 100:.4}"
        )
        logging.info(f"best_th: {best_th}")
        logging.info(f"best_f1: {best_f1}")
        return best_th, best_f1
Example #3
    def show_best(self, pp_params, save_predictions=None):
        # Set applying post-processing functions
        post_processing_fn = []
        if "threshold" in pp_params.keys():
            threshold = list(pp_params["threshold"].values())
            binarization_type = "class_threshold"
        else:
            threshold = 0.5
            binarization_type = "global_threshold"
        if "median_filtering" in pp_params.keys():
            filt_span = list(pp_params["median_filtering"].values())
            post_processing_fn.append(
                functools.partial(median_filt_1d, filt_span=filt_span))
        if "fill_up_gap" in pp_params.keys():
            accept_gap = list(pp_params["fill_up_gap"].values())
            post_processing_fn.append(
                functools.partial(fill_up_gap, accept_gap=accept_gap))
        if len(post_processing_fn) == 0:
            post_processing_fn = None

        prediction_df = self.get_prediction_dataframe(
            post_processing=post_processing_fn,
            threshold=threshold,
            binarization_type=binarization_type,
        )

        # Compute evaluation metrics
        events_metric, segments_metric, psds_m_f1 = compute_metrics(
            prediction_df, self.validation_df, self.durations_validation)
        macro_f1_event = events_metric.results_class_wise_average_metrics(
        )["f_measure"]["f_measure"]
        macro_f1_segment = segments_metric.results_class_wise_average_metrics(
        )["f_measure"]["f_measure"]

        logging.info(f"Event-based macro F1: {macro_f1_event}")
        logging.info(f"Segment-based macro F1: {macro_f1_segment}")
Example #4
                                   mask_weak=weak_mask,
                                   mask_strong=strong_mask,
                                   adjust_lr=cfg.adjust_lr)

        # Validation
        crnn = crnn.eval()
        logger.info("\n ### Valid synthetic metric ### \n")
        predictions = get_predictions(crnn,
                                      valid_synth_loader,
                                      many_hot_encoder.decode_strong,
                                      pooling_time_ratio,
                                      median_window=median_window,
                                      save_predictions=None)
        # Validation with synthetic data (dropping feature_filename for psds)
        valid_synth = dfs["valid_synthetic"].drop("feature_filename", axis=1)
        valid_synth_f1, psds_m_f1 = compute_metrics(predictions, valid_synth,
                                                    durations_synth)

        # ---------------
        # Save the trainable PCEN parameters (if any)
        if crnn.trainable_pcen is not None:
            pcen_parameters.append({
                key: torch.tensor(value)
                for key, value in crnn.trainable_pcen.state_dict().items()
            })
        # print("PCEN PARAMS:", pcen_parameters)
        torch.save(pcen_parameters, 'params/best_4layers_0688_opt_2.p')
        # ---------------

        # Save meters
        meters_file.append(meters.values())
        np.save('meters/best_4layers_0688_opt_2', meters_file)
Example #5
                    compute_log=False)

    gt_df_feat = dataset.initialize_and_get_df(f_args.groundtruth_tsv,
                                               gt_audio_dir,
                                               nb_files=f_args.nb_files)
    params = _load_state_vars(expe_state, gt_df_feat, median_window)

    # Predictions using a single threshold value
    single_predictions = get_predictions(
        params["model"],
        params["dataloader"],
        params["many_hot_encoder"].decode_strong,
        params["pooling_time_ratio"],
        median_window=params["median_window"],
        save_predictions=f_args.save_predictions_path)
    compute_metrics(single_predictions, groundtruth, durations)

    # ##########
    # Optional but recommended
    # ##########
    # Compute psds scores with multiple thresholds (more accurate). n_thresholds could be increased.
    n_thresholds = 50
    # Example of 5 thresholds: 0.1, 0.3, 0.5, 0.7, 0.9
    list_thresholds = np.arange(1 / (n_thresholds * 2), 1, 1 / n_thresholds)
    pred_ss_thresh = get_predictions(
        params["model"],
        params["dataloader"],
        params["many_hot_encoder"].decode_strong,
        params["pooling_time_ratio"],
        thresholds=list_thresholds,
        median_window=params["median_window"],
Example #6
if __name__ == '__main__':
    groundtruth_path = "../dataset/metadata/validation/validation.tsv"
    durations_path = "../dataset/metadata/validation/validation_durations.tsv"
    # If durations do not exist, the audio dir is needed
    groundtruth_audio_path = "../dataset/audio/validation"
    base_prediction_path = "stored_data/MeanTeacher_with_synthetic/predictions/baseline_validation"

    groundtruth = pd.read_csv(groundtruth_path, sep="\t")
    if osp.exists(durations_path):
        meta_dur_df = pd.read_csv(durations_path, sep='\t')
    else:
        meta_dur_df = generate_tsv_wav_durations(groundtruth_audio_path,
                                                 durations_path)

    # Evaluate a single prediction
    single_predictions = pd.read_csv(base_prediction_path + ".tsv", sep="\t")
    compute_metrics(single_predictions, groundtruth, meta_dur_df)

    # Evaluate predictions computed at multiple thresholds (more accurate); needs a list of prediction DataFrames.
    prediction_list_path = glob.glob(osp.join(base_prediction_path, "*.tsv"))
    list_predictions = []
    for fname in prediction_list_path:
        pred_df = pd.read_csv(fname, sep="\t")
        list_predictions.append(pred_df)
    psds = compute_psds_from_operating_points(list_predictions, groundtruth,
                                              meta_dur_df)
    psds_score(psds,
               filename_roc_curves=osp.join(base_prediction_path,
                                            "figures/psds_roc.png"))
Example #7
    def validation(self, post_processing=None) -> None:
        self.strong_losses.reset()
        self.weak_losses.reset()

        prediction_df = pd.DataFrame()
        threshold = self.options.threshold
        decoder = self.options.decoder
        binarization_type = self.options.binarization_type

        ptr = self.options.pooling_time_ratio

        # Frame level measure
        frame_measure = [
            ConfusionMatrix() for i in range(len(self.options.classes))
        ]
        tag_measure = ConfusionMatrix()

        self.model.eval()

        for (data, target, data_ids) in self.valid_loader:
            data, target = data.to(self.device), target.to(self.device)
            predicts = self.model(data)

            # compute classification loss
            strong_class_loss = self.classification_criterion(
                predicts["strong"], target)
            weak_class_loss = self.classification_criterion(
                predicts["weak"],
                target.max(dim=1)[0])
            self.strong_losses.update(strong_class_loss.item())
            self.weak_losses.update(weak_class_loss.item())

            predicts["strong"] = torch.sigmoid(
                predicts["strong"]).cpu().data.numpy()
            predicts["weak"] = torch.sigmoid(
                predicts["weak"]).cpu().data.numpy()

            if binarization_type == "class_threshold":
                for i in range(predicts["strong"].shape[0]):
                    predicts["strong"][i] = ProbabilityEncoder().binarization(
                        predicts["strong"][i],
                        binarization_type=binarization_type,
                        threshold=threshold,
                        time_axis=0,
                    )
            else:
                predicts["strong"] = ProbabilityEncoder().binarization(
                    predicts["strong"],
                    binarization_type=binarization_type,
                    threshold=threshold,
                )
                predicts["weak"] = ProbabilityEncoder().binarization(
                    predicts["weak"],
                    binarization_type=binarization_type,
                    threshold=threshold,
                )

            # For debug, frame level measure
            for i in range(len(predicts["strong"])):
                target_np = target.cpu().numpy()
                tn, fp, fn, tp = confusion_matrix(target_np[i].max(axis=0),
                                                  predicts["weak"][i],
                                                  labels=[0, 1]).ravel()
                tag_measure.add_cf(tn, fp, fn, tp)
                for j in range(len(self.options.classes)):
                    tn, fp, fn, tp = confusion_matrix(target_np[i][:, j],
                                                      predicts["strong"][i][:,
                                                                            j],
                                                      labels=[0, 1]).ravel()
                    frame_measure[j].add_cf(tn, fp, fn, tp)

            if post_processing is not None:
                for i in range(predicts["strong"].shape[0]):
                    for post_process_fn in post_processing:
                        predicts["strong"][i] = post_process_fn(
                            predicts["strong"][i])

            for pred, data_id in zip(predicts["strong"], data_ids):
                pred = decoder(pred)
                pred = pd.DataFrame(pred,
                                    columns=["event_label", "onset", "offset"])

                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= ptr / (
                    self.options.sample_rate / self.options.hop_size)
                pred.loc[:,
                         ["onset", "offset"]] = pred[["onset", "offset"]].clip(
                             0, self.options.max_len_seconds)

                pred["filename"] = data_id
                prediction_df = pd.concat([prediction_df, pred], ignore_index=True)

        else:
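            # Note: for/else construct; since the loop above has no break,
            # this block runs once after all validation batches are processed.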
            # save predictions
            prediction_df.to_csv(
                self.exp_name / "predictions" /
                f"{self.forward_count}th_iterations.csv",
                index=False,
                sep="\t",
                float_format="%.3f",
            )

            # Compute evaluation metrics
            events_metric, segments_metric, psds_m_f1 = compute_metrics(
                prediction_df,
                self.options.validation_df,
                self.options.durations_validation,
            )
            macro_f1_event = events_metric.results_class_wise_average_metrics(
            )["f_measure"]["f_measure"]
            macro_f1_segment = segments_metric.results_class_wise_average_metrics(
            )["f_measure"]["f_measure"]

            # Compute frame level macro f1 score
            ave_precision = 0
            ave_recall = 0
            macro_f1 = 0
            for i in range(len(self.options.classes)):
                ave_precision_, ave_recall_, macro_f1_ = frame_measure[
                    i].calc_f1()
                ave_precision += ave_precision_
                ave_recall += ave_recall_
                macro_f1 += macro_f1_
            ave_precision /= len(self.options.classes)
            ave_recall /= len(self.options.classes)
            macro_f1 /= len(self.options.classes)
            weak_f1 = tag_measure.calc_f1()[2]

            metrics = {
                "valid_strong_loss": self.strong_losses.avg,
                "valid_weak_loss": self.weak_losses.avg,
                "event_m_f1": macro_f1_event,
                "segment_m_f1": macro_f1_segment,
                "psds_m_f1": psds_m_f1,
                "frame_level_precision": ave_precision,
                "frame_level_recall": ave_recall,
                "frame_level_macro_f1": macro_f1,
                "weak_f1": weak_f1,
            }

            wandb.log(metrics, step=self.forward_count)

        return metrics
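
The ConfusionMatrix helper used for the frame-level measures is not included in the listing. A minimal sketch that matches the add_cf and calc_f1 calls above; the project's own class may differ.

class ConfusionMatrix:
    # Accumulates tn/fp/fn/tp counts and reports precision, recall and F1.
    def __init__(self):
        self.tn = self.fp = self.fn = self.tp = 0

    def add_cf(self, tn, fp, fn, tp):
        self.tn += tn
        self.fp += fp
        self.fn += fn
        self.tp += tp

    def calc_f1(self):
        precision = self.tp / (self.tp + self.fp) if (self.tp + self.fp) else 0.0
        recall = self.tp / (self.tp + self.fn) if (self.tp + self.fn) else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if (precision + recall) else 0.0)
        return precision, recall, f1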
Example #8
        PS_model.train()
        MS_model, MS_ema, PS_model = to_cuda_if_available(MS_model, MS_ema, PS_model)

        loss_value = train(training_loader, MS_model, PS_model, ms_optim, ps_optim, epoch,
                           ema_model=MS_ema, mask_weak=weak_mask, mask_strong=strong_mask, adjust_lr=cfg.adjust_lr)

        # Validation
        ema = MS_ema.eval()
        MS_m = MS_model.eval()
        PS_m = PS_model.eval()
        logger.info("\n ### Valid synthetic metric ### \n")
        predictions = get_predictions(PS_m, ema, MS_m, valid_synth_loader, many_hot_encoder.decode_strong, pooling_time_ratio,
                                      median_window=median_window, save_predictions=None)
        # Validation with synthetic data (dropping feature_filename for psds)
        valid_synth = dfs["validation"].drop("feature_filename", axis=1)
        valid_synth_f1, psds_m_f1 = compute_metrics(predictions, valid_synth, durations_validation)

        # Update state
        state['MS']['state_dict'] = MS_model.state_dict()
        state['PS']['state_dict'] = PS_model.state_dict()
        state['model_ema']['state_dict'] = MS_ema.state_dict()
        state['ms_optimizer']['state_dict'] = ms_optim.state_dict()
        state['ps_optimizer']['state_dict'] = ps_optim.state_dict()
        state['epoch'] = epoch
        state['valid_metric'] = valid_synth_f1
        state['valid_f1_psds'] = psds_m_f1

        # Callbacks
        if cfg.checkpoint_epochs is not None and (epoch + 1) % cfg.checkpoint_epochs == 0:
            model_fname = os.path.join(saved_model_dir, "ms_epoch_" + str(epoch))
            torch.save(state, model_fname)