Example #1
def get_f_measure_by_class(outputs, nb_tags, threshold=None):
    TP = np.zeros(nb_tags)
    TN = np.zeros(nb_tags)
    FP = np.zeros(nb_tags)
    FN = np.zeros(nb_tags)

    binarization_type = 'global_threshold'
    probability_encoder = ProbabilityEncoder()
    threshold = 0.5 if threshold is None else threshold
    for predictions, utt_targets in outputs:
        predictions = probability_encoder.binarization(predictions,
                                                       binarization_type=binarization_type,
                                                       threshold=threshold,
                                                       time_axis=0
                                                       )
        TP += (predictions + utt_targets == 2).sum(axis=0)
        FP += (predictions - utt_targets == 1).sum(axis=0)
        FN += (utt_targets - predictions == 1).sum(axis=0)
        TN += (predictions + utt_targets == 0).sum(axis=0)

    macro_f_measure = np.zeros(nb_tags)
    mask_f_score = 2*TP + FP + FN != 0
    macro_f_measure[mask_f_score] = 2 * \
        TP[mask_f_score] / (2*TP + FP + FN)[mask_f_score]

    return macro_f_measure
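A minimal smoke test for the function above; the names and shapes are illustrative assumptions (outputs is taken to be an iterable of (predictions, targets) numpy pairs of shape (n_frames, nb_tags), and ProbabilityEncoder comes from dcase_util.data):

import numpy as np

# Two hypothetical batches of frame-level probabilities with multi-hot targets
rng = np.random.default_rng(0)
outputs = [(rng.random((8, 3)), rng.integers(0, 2, (8, 3))) for _ in range(2)]
per_class_f1 = get_f_measure_by_class(outputs, nb_tags=3)
print(per_class_f1)  # one F-measure per tag; 0 where 2*TP + FP + FN == 0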
Example #2
def get_predictions(model, valid_dataset, decoder, save_predictions=None):
    for i, (input, _) in enumerate(valid_dataset):
        [input] = to_cuda_if_available([input])

        pred_strong, _ = model(input.unsqueeze(0))
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.squeeze(0).detach().numpy()
        if i == 0:
            LOG.debug(pred_strong)
        pred_strong = ProbabilityEncoder().binarization(pred_strong, binarization_type="global_threshold",
                                                        threshold=0.5)
        pred_strong = scipy.ndimage.median_filter(pred_strong, (cfg.median_window, 1))
        pred = decoder(pred_strong)
        pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
        pred["filename"] = valid_dataset.filenames.iloc[i]
        if i == 0:
            LOG.debug("predictions: \n{}".format(pred))
            LOG.debug("predictions strong: \n{}".format(pred_strong))
            prediction_df = pred.copy()
        else:
            prediction_df = pd.concat([prediction_df, pred])  # DataFrame.append was removed in pandas 2

    if save_predictions is not None:
        LOG.info("Saving predictions at: {}".format(save_predictions))
        prediction_df.to_csv(save_predictions, index=False, sep="\t")
    return prediction_df
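The decoder passed in above is expected to map a binarized (time, classes) array to (event_label, onset, offset) tuples; the times here are frame indices, and depending on the repository the conversion to seconds happens inside the decoder or afterwards (as in examples #5 and #9). A minimal sketch of such a decoder, assuming only a list of class names; the edge-detection approach is an illustration, not the repository's actual implementation:

import numpy as np

def make_decoder(labels):
    def decoder(binary_matrix):
        events = []
        for class_idx, label in enumerate(labels):
            column = binary_matrix[:, class_idx]
            # rising (+1) and falling (-1) edges of the binary activity curve
            changes = np.diff(np.concatenate(([0], column, [0])))
            onsets = np.where(changes == 1)[0]
            offsets = np.where(changes == -1)[0]
            events += [(label, onset, offset) for onset, offset in zip(onsets, offsets)]
        return events
    return decoder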
Example #3
def get_weak_predictions(model,
                         valid_dataset,
                         weak_decoder,
                         save_predictions=None):
    for i, (data, _) in enumerate(valid_dataset):
        data = to_cuda_if_available(data)

        pred_weak = model(data.unsqueeze(0))
        pred_weak = pred_weak.cpu()
        pred_weak = pred_weak.squeeze(0).detach().numpy()
        if i == 0:
            LOG.debug(pred_weak)
        pred_weak = ProbabilityEncoder().binarization(
            pred_weak, binarization_type="global_threshold", threshold=0.5)
        pred = weak_decoder(pred_weak)
        pred = pd.DataFrame(pred, columns=["event_labels"])
        pred["filename"] = valid_dataset.filenames.iloc[i]
        if i == 0:
            LOG.debug("predictions: \n{}".format(pred))
            prediction_df = pred.copy()
        else:
            prediction_df = pd.concat([prediction_df, pred])  # DataFrame.append was removed in pandas 2

    if save_predictions is not None:
        LOG.info("Saving predictions at: {}".format(save_predictions))
        prediction_df.to_csv(save_predictions, index=False, sep="\t")
    return prediction_df
Example #4
def get_f_measure_by_class(keras_model, nb_tags, generator, steps, thresholds=None):
    """ get f measure for each class given a model and a generator of data (X, y)
    Parameters
    ----------
    keras_model : Model, model to get predictions
    nb_tags : int, number of classes which are represented
    generator : generator, data generator used to get f_measure
    steps : int, number of steps the generator will be used before stopping
    thresholds : float or list, threshold(s) to apply to each class to binarize probabilities
    Return
    ------
    macro_f_measure : float, mean over classes of the per-class f measure
    """

    # Calculate external metrics
    TP = numpy.zeros(nb_tags)
    TN = numpy.zeros(nb_tags)
    FP = numpy.zeros(nb_tags)
    FN = numpy.zeros(nb_tags)
    for counter, (X, y) in enumerate(generator):
        if counter == steps:
            break
        predictions = keras_model.predict(X)

        if len(predictions.shape) == 3:
            # average data to have weak labels
            predictions = numpy.mean(predictions, axis=1)
            y = numpy.mean(y, axis=1)

        if thresholds is None:
            binarization_type = 'global_threshold'
            thresh = 0.2
        else:
            if type(thresholds) is list:
                thresh = thresholds
                binarization_type = "class_threshold"
            else:
                binarization_type = "global_threshold"
                thresh = thresholds
        predictions = ProbabilityEncoder().binarization(predictions,
                                                        binarization_type=binarization_type,
                                                        threshold=thresh,
                                                        time_axis=0
                                                        )

        TP += (predictions + y == 2).sum(axis=0)
        FP += (predictions - y == 1).sum(axis=0)
        FN += (y - predictions == 1).sum(axis=0)
        TN += (predictions + y == 0).sum(axis=0)

    macro_f_measure = numpy.zeros(nb_tags)
    mask_f_score = 2*TP + FP + FN != 0
    macro_f_measure[mask_f_score] = 2*TP[mask_f_score] / (2*TP + FP + FN)[mask_f_score]

    return numpy.mean(macro_f_measure)
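For intuition, binarization with global_threshold amounts to an elementwise comparison against one scalar, while class_threshold compares each column against its own threshold. A rough numpy equivalent, as an assumption about the observable behaviour rather than dcase_util's actual source:

import numpy as np

def binarize(probabilities, threshold):
    # a scalar threshold applies globally; a per-class list broadcasts over columns
    return (probabilities > np.asarray(threshold)).astype(int)

probs = np.array([[0.1, 0.7],
                  [0.6, 0.3]])
print(binarize(probs, 0.5))           # global threshold: [[0 1] [1 0]]
print(binarize(probs, [0.05, 0.75]))  # one threshold per class: [[1 0] [1 0]]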
Example #5
def generate_label(pred):
    p = []
    for b in range(pred.size(0)):
        pred_bin = ProbabilityEncoder().binarization(pred[b].detach().cpu().numpy(),
                                                     binarization_type="global_threshold",
                                                     threshold=0.5)
        p.append(pred_bin)

    p = torch.FloatTensor(p)
    p = to_cuda_if_available(p)
    return p
Example #6
def get_predictions(model,
                    valid_dataloader,
                    decoder,
                    pooling_time_ratio=1,
                    median_window=1,
                    save_predictions=None):
    prediction_df = pd.DataFrame()
    for i, ((input_data, _), indexes) in enumerate(valid_dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)

        pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        for j, pred_strong_it in enumerate(pred_strong):
            pred_strong_it = ProbabilityEncoder().binarization(
                pred_strong_it,
                binarization_type="global_threshold",
                threshold=0.5)
            pred_strong_it = scipy.ndimage.median_filter(
                pred_strong_it, (median_window, 1))
            pred = decoder(pred_strong_it)
            pred = pd.DataFrame(pred,
                                columns=["event_label", "onset", "offset"])
            pred["filename"] = valid_dataloader.dataset.filenames.iloc[
                indexes[j]]
            prediction_df = pd.concat([prediction_df, pred])

            if i == 0 and j == 0:
                logger.debug("predictions: \n{}".format(pred))
                logger.debug("predictions strong: \n{}".format(pred_strong_it))

    # In seconds
    prediction_df.loc[:,
                      "onset"] = prediction_df.onset * pooling_time_ratio / (
                          cfg.sample_rate / cfg.hop_size)
    prediction_df.loc[:,
                      "offset"] = prediction_df.offset * pooling_time_ratio / (
                          cfg.sample_rate / cfg.hop_size)
    prediction_df = prediction_df.reset_index(drop=True)
    if save_predictions is not None:
        dir_to_create = osp.dirname(save_predictions)
        if dir_to_create != "":
            os.makedirs(dir_to_create, exist_ok=True)
        logger.info("Saving predictions at: {}".format(save_predictions))
        prediction_df.to_csv(save_predictions,
                             index=False,
                             sep="\t",
                             float_format="%.3f")
    return prediction_df
Example #7
def generate_label(pred):
    p = []
    thresh = [0.45, 0.5, 0.5, 0.5, 0.5, 0.5, 0.45, 0.45, 0.5, 0.35]
    for b in range(pred.size(0)):
        pred_bin = np.array([])
        for c in range(10):
            pred_b = ProbabilityEncoder().binarization(
                pred[b][c].unsqueeze(0).cpu().detach().numpy(),
                binarization_type="global_threshold",
                threshold=thresh[c])
            pred_bin = np.concatenate((pred_bin, pred_b), axis=0)
        p.append(pred_bin)

    p = torch.FloatTensor(p)
    p = to_cuda_if_available(p)
    return p
Example #8
def get_f_measure_by_class(torch_model,
                           nb_tags,
                           dataloader_,
                           thresholds_=None):
    """ get f measure for each class given a model and a generator of data (batch_x, y)

    Args:
        torch_model : Model, model to get predictions, forward should return weak and strong predictions
        nb_tags : int, number of classes which are represented
        dataloader_ : generator, data generator used to get f_measure
        thresholds_ : int or list, thresholds to apply to each class to binarize probabilities

    Returns:
        macro_f_measure : list, f measure for each class

    """
    if torch.cuda.is_available():
        torch_model = torch_model.cuda()

    # Calculate external metrics
    tp = np.zeros(nb_tags)
    tn = np.zeros(nb_tags)
    fp = np.zeros(nb_tags)
    fn = np.zeros(nb_tags)
    for counter, (batch_x, y) in enumerate(dataloader_):
        if torch.cuda.is_available():
            batch_x = batch_x.cuda()

        pred_strong, pred_weak = torch_model(batch_x)
        pred_weak = pred_weak.cpu().data.numpy()
        labels = y.numpy()

        # Used only with a model predicting only strong outputs
        if len(pred_weak.shape) == 3:
            # take the max over time to obtain weak (clip-level) labels
            pred_weak = np.max(pred_weak, axis=1)

        if len(labels.shape) == 3:
            labels = np.max(labels, axis=1)
            labels = ProbabilityEncoder().binarization(
                labels, binarization_type="global_threshold", threshold=0.5)

        if thresholds_ is None:
            binarization_type = 'global_threshold'
            thresh = 0.5
        else:
            binarization_type = "class_threshold"
            assert type(thresholds_) is list
            thresh = thresholds_

        batch_predictions = ProbabilityEncoder().binarization(
            pred_weak,
            binarization_type=binarization_type,
            threshold=thresh,
            time_axis=0)

        tp_, fp_, fn_, tn_ = intermediate_at_measures(labels,
                                                      batch_predictions)
        tp += tp_
        fp += fp_
        fn += fn_
        tn += tn_

    macro_f_score = np.zeros(nb_tags)
    mask_f_score = 2 * tp + fp + fn != 0
    macro_f_score[mask_f_score] = 2 * tp[mask_f_score] / (2 * tp + fp +
                                                          fn)[mask_f_score]

    return macro_f_score
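intermediate_at_measures is not shown in this example; a minimal sketch consistent with how it is called above (binary arrays of shape (batch, nb_tags) in, per-class counts out), mirroring the counting used in examples #1 and #4:

import numpy as np

def intermediate_at_measures(encoded_ref, encoded_est):
    # per-class TP, FP, FN, TN for binary reference/estimate arrays
    tp = (encoded_est + encoded_ref == 2).sum(axis=0)
    fp = (encoded_est - encoded_ref == 1).sum(axis=0)
    fn = (encoded_ref - encoded_est == 1).sum(axis=0)
    tn = (encoded_est + encoded_ref == 0).sum(axis=0)
    return tp, fp, fn, tn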
Example #9
def get_predictions(model,
                    dataloader,
                    decoder,
                    pooling_time_ratio=1,
                    thresholds=[0.5],
                    median_window=1,
                    save_predictions=None):
    """ Get the predictions of a trained model on a specific set
    Args:
        model: torch.Module, a trained pytorch model (you usually want it to be in .eval() mode).
        dataloader: torch.utils.data.DataLoader, giving ((input_data, label), indexes) but label is not used here
        decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and return a list of lists
            of ("event_label", "onset", "offset") for each label predicted.
        pooling_time_ratio: int, ratio between the number of input timesteps and the number of output timesteps
        median_window: int, the median window (in number of time steps) to be applied
        save_predictions: str or list, the base filename used to save the predictions, or a list of filenames,
            one per threshold
        thresholds: list, list of thresholds to be applied

    Returns:
        dict of the different predictions with associated threshold
    """

    # Init a dataframe per threshold
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()

    # Get predictions
    for i, ((input_data, _), indexes) in enumerate(dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)
        with torch.no_grad():
            pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        # Post processing and put predictions in a dataframe
        for j, pred_strong_it in enumerate(pred_strong):
            for threshold in thresholds:
                pred_strong_bin = ProbabilityEncoder().binarization(
                    pred_strong_it,
                    binarization_type="global_threshold",
                    threshold=threshold)
                pred_strong_m = scipy.ndimage.median_filter(
                    pred_strong_bin, (median_window, 1))
                pred = decoder(pred_strong_m)
                pred = pd.DataFrame(pred,
                                    columns=["event_label", "onset", "offset"])
                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= pooling_time_ratio / (
                    cfg.sample_rate / cfg.hop_size)
                pred.loc[:,
                         ["onset", "offset"]] = pred[["onset", "offset"]].clip(
                             0, cfg.max_len_seconds)

                pred["filename"] = dataloader.dataset.filenames.iloc[
                    indexes[j]]
                prediction_dfs[threshold] = pd.concat(
                    [prediction_dfs[threshold], pred], ignore_index=True)

                if i == 0 and j == 0:
                    logger.debug("predictions: \n{}".format(pred))
                    logger.debug(
                        "predictions strong: \n{}".format(pred_strong_it))

    # Save predictions
    if save_predictions is not None:
        if isinstance(save_predictions, str):
            if len(thresholds) == 1:
                save_predictions = [save_predictions]
            else:
                base, ext = osp.splitext(save_predictions)
                save_predictions = [
                    osp.join(base, f"{threshold:.3f}{ext}")
                    for threshold in thresholds
                ]
        else:
            assert len(save_predictions) == len(thresholds), \
                f"There should be a prediction file per threshold: len predictions: {len(save_predictions)}\n" \
                f"len thresholds: {len(thresholds)}"

        for ind, threshold in enumerate(thresholds):
            dir_to_create = osp.dirname(save_predictions[ind])
            if dir_to_create != "":
                os.makedirs(dir_to_create, exist_ok=True)
            # save even when the path has no directory component
            if ind % 10 == 0:
                logger.info(
                    f"Saving predictions at: {save_predictions[ind]}. {ind + 1} / {len(thresholds)}"
                )
            prediction_dfs[threshold].to_csv(save_predictions[ind],
                                             index=False,
                                             sep="\t",
                                             float_format="%.3f")

    list_predictions = list(prediction_dfs.values())

    if len(list_predictions) == 1:
        list_predictions = list_predictions[0]

    return list_predictions
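A hypothetical call for the function above, sweeping several thresholds; the model, dataloader and decoder objects and the output path are assumed to exist:

thresholds = [0.3, 0.5, 0.7]
prediction_dfs = get_predictions(model, dataloader, decoder,
                                 pooling_time_ratio=4,
                                 thresholds=thresholds,
                                 median_window=5,
                                 save_predictions="predictions/val.tsv")
# With more than one threshold, "predictions/val" becomes a directory holding
# 0.300.tsv, 0.500.tsv and 0.700.tsv, and a list of DataFrames is returned;
# with a single threshold, one file is written and one DataFrame is returned.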
Example #10
    def get_prediction_dataframe(
        self,
        post_processing=None,
        save_predictions=None,
        transforms=None,
        mode="validation",
        threshold=0.5,
        binarization_type="global_threshold",
    ):
        """
        post_processing: e.g. [functools.partial(median_filt_1d, filt_span=39)]
        """
        prediction_df = pd.DataFrame()

        # Frame level
        frame_measure = [ConfusionMatrix() for i in range(len(self.labels))]
        tag_measure = ConfusionMatrix()

        for batch_idx, data in enumerate(self.data_loader):
            output = {}
            output["strong"] = data["pred_strong"].cpu().data.numpy()
            output["weak"] = data["pred_weak"].cpu().data.numpy()

            # Binarize score into predicted label
            if binarization_type == "class_threshold":
                for i in range(output["strong"].shape[0]):
                    output["strong"][i] = ProbabilityEncoder().binarization(
                        output["strong"][i],
                        binarization_type=binarization_type,
                        threshold=threshold,
                        time_axis=0,
                    )
            elif binarization_type == "global_threshold":
                output["strong"] = ProbabilityEncoder().binarization(
                    output["strong"],
                    binarization_type=binarization_type,
                    threshold=threshold,
                )
            else:
                raise ValueError(
                    "binarization_type must be 'class_threshold' or 'global_threshold'"
                )
            weak = ProbabilityEncoder().binarization(
                output["weak"],
                binarization_type="global_threshold",
                threshold=0.5)

            for pred, data_id in zip(output["strong"], data["data_id"]):
                # Apply post processing if exists
                if post_processing is not None:
                    for post_process_fn in post_processing:
                        pred = post_process_fn(pred)

                pred = self.decoder(pred)
                pred = pd.DataFrame(pred,
                                    columns=["event_label", "onset", "offset"])
                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= self.pooling_time_ratio / (
                    self.sample_rate / self.hop_size)
                pred.loc[:,
                         ["onset", "offset"]] = pred[["onset", "offset"]].clip(
                             0, self.options.max_len_seconds)
                pred["filename"] = data_id
                prediction_df = pd.concat([prediction_df, pred], ignore_index=True)

        return prediction_df
Example #11
def evaluate_threshold(model_path: str,
                       features: str = "features/logmel_64/test.ark",
                       result_filename='dev.txt',
                       test_labels: str = "metadata/test/test.csv",
                       threshold=0.5,
                       window=1,
                       hop_size=0.02):
    from dcase_util.data import ProbabilityEncoder, DecisionEncoder, ManyHotEncoder
    from dcase_util.containers import MetaDataContainer
    from scipy.signal import medfilt
    modeldump = torch.load(
        model_path,
        map_location=lambda storage, loc: storage)
    model = modeldump['model']
    config_parameters = modeldump['config']
    scaler = modeldump['scaler']
    many_hot_encoder = modeldump['encoder']
    model_dirname = os.path.dirname(model_path)
    meta_container_resultfile = os.path.join(
        model_dirname, "pred_nowindow.txt")
    metacontainer = MetaDataContainer(filename=meta_container_resultfile)

    kaldi_string = parsecopyfeats(
        features, **config_parameters['feature_args'])
    model = model.to(device).eval()

    probability_encoder = ProbabilityEncoder()
    decision_encoder = DecisionEncoder(
        label_list=many_hot_encoder.label_list
    )
    binarization_type = 'global_threshold' if isinstance(
        threshold, float) else 'class_threshold'
    # If class thresholds are given, then use those
    if isinstance(threshold, str):
        threshold = torch.load(threshold)
    windows = {k: window for k in many_hot_encoder.label_list}
    if isinstance(window, str):
        windows = torch.load(window)

    with torch.no_grad():
        for k, feat in kaldi_io.read_mat_ark(kaldi_string):
            # Add batch dim
            feat = torch.from_numpy(
                scaler.transform(feat)).to(device).unsqueeze(0)
            feat = model(feat)
            probabilities = torch.sigmoid(feat).cpu().numpy().squeeze(0)
            frame_decisions = probability_encoder.binarization(
                probabilities=probabilities,
                binarization_type=binarization_type,
                threshold=threshold,
                time_axis=0,
            )
            for i, label in enumerate(many_hot_encoder.label_list):
                label_frame_decisions = medfilt(
                    frame_decisions[:, i], kernel_size=windows[label])
                # Found only zeros, no activity, go on
                if (label_frame_decisions == 0).all():
                    continue
                estimated_events = decision_encoder.find_contiguous_regions(
                    activity_array=label_frame_decisions
                )
                for [onset, offset] in estimated_events:
                    metacontainer.append({'event_label': label,
                                          'onset': onset * hop_size,
                                          'offset': offset * hop_size,
                                          'filename': os.path.basename(k)
                                          })
    metacontainer.save()
    estimated_event_list = MetaDataContainer().load(
        filename=meta_container_resultfile)
    reference_event_list = MetaDataContainer().load(filename=test_labels)

    event_based_metric = event_based_evaluation(
        reference_event_list, estimated_event_list)
    onset_scores = precision_recall_fscore_on_offset(
        reference_event_list, estimated_event_list, offset=False)
    offset_scores = precision_recall_fscore_on_offset(
        reference_event_list, estimated_event_list, onset=False)
    onset_offset_scores = precision_recall_fscore_on_offset(
        reference_event_list, estimated_event_list)
    # Utt wise Accuracy
    precision_labels = precision_recall_fscore_on_offset(
        reference_event_list, estimated_event_list, onset=False, offset=False, label=True)

    print(event_based_metric.__str__())
    print("{:>10}-Precision: {:.1%} Recall {:.1%} F-Score {:.1%}".format("UttLabel", *precision_labels))
    print("{:>10}-Precision: {:.1%} Recall {:.1%} F-Score {:.1%}".format("Onset", *onset_scores))
    print("{:>10}-Precision: {:.1%} Recall {:.1%} F-Score {:.1%}".format("Offset", *offset_scores))
    print("{:>10}-Precision: {:.1%} Recall {:.1%} F-Score {:.1%}".format("On-Offset", *onset_offset_scores))

    result_filename = os.path.join(model_dirname, result_filename)

    with open(result_filename, 'w') as wp:
        wp.write(event_based_metric.__str__())
        wp.write('\n')
        wp.write("{:>10}: Precision: {:.1%} Recall {:.1%} F-Score {:.1%}\n".format(
            "UttLabel", *precision_labels))
        wp.write(
            "{:>10}: Precision: {:.1%} Recall {:.1%} F-Score {:.1%}\n".format("Onset", *onset_scores))
        wp.write(
            "{:>10}: Precision: {:.1%} Recall {:.1%} F-Score {:.1%}\n".format("Offset", *offset_scores))
        wp.write("{:>10}: Precision: {:.1%} Recall {:.1%} F-Score {:.1%}\n".format(
            "On-Offset", *onset_offset_scores))
Example #12
def get_f_measure_by_class(torch_model,
                           nb_tags,
                           dataloader_,
                           thresholds_=None,
                           max=False):
    """ get f measure for each class given a model and a generator of data (batch_x, y)

    Args:
        torch_model : Model, model to get predictions, forward should return weak and strong predictions
        nb_tags : int, number of classes which are represented
        dataloader_ : generator, data generator used to get f_measure
        thresholds_ : int or list, thresholds to apply to each class to binarize probabilities
        max: bool, whether or not to take the max of the predictions

    Returns:
        macro_f_measure : list, f measure for each class

    """
    torch_model = to_cuda_if_available(torch_model)

    # Calculate external metrics
    tp = np.zeros(nb_tags)
    tn = np.zeros(nb_tags)
    fp = np.zeros(nb_tags)
    fn = np.zeros(nb_tags)
    for counter, (batch_x, y) in enumerate(dataloader_):
        if torch.cuda.is_available():
            batch_x = batch_x.cuda()

        pred_weak = torch_model(batch_x)
        pred_weak = pred_weak.cpu().data.numpy()
        labels = y.numpy()

        # Used only with a model predicting only strong outputs
        if len(pred_weak.shape) == 3:
            # take the max over time: presence anywhere in the clip gives the weak label
            pred_weak = np.max(pred_weak, axis=1)

        if len(labels.shape) == 3:
            labels = np.max(labels, axis=1)
            labels = ProbabilityEncoder().binarization(
                labels, binarization_type="global_threshold", threshold=0.5)
        if counter == 0:
            LOG.info(
                f"shapes, input: {batch_x.shape}, output: {pred_weak.shape}, label: {labels.shape}"
            )

        if not max:
            if thresholds_ is None:
                binarization_type = 'global_threshold'
                thresh = 0.5
            else:
                binarization_type = "class_threshold"
                assert type(thresholds_) is list
                thresh = thresholds_

            batch_predictions = ProbabilityEncoder().binarization(
                pred_weak,
                binarization_type=binarization_type,
                threshold=thresh,
                time_axis=0)
        else:
            batch_predictions = np.zeros(pred_weak.shape)
            # one-hot per sample: index rows explicitly, otherwise whole columns are set
            batch_predictions[np.arange(len(pred_weak)), pred_weak.argmax(axis=1)] = 1

        tp_, fp_, fn_, tn_ = intermediate_at_measures(labels,
                                                      batch_predictions)
        tp += tp_
        fp += fp_
        fn += fn_
        tn += tn_

    print("Macro measures: TP: {}\tFP: {}\tFN: {}\tTN: {}".format(
        tp, fp, fn, tn))

    macro_f_score = np.zeros(nb_tags)
    mask_f_score = 2 * tp + fp + fn != 0
    macro_f_score[mask_f_score] = 2 * tp[mask_f_score] / (2 * tp + fp +
                                                          fn)[mask_f_score]

    return macro_f_score
Example #13
    def validation(self, post_processing=None) -> dict:
        self.strong_losses.reset()
        self.weak_losses.reset()

        prediction_df = pd.DataFrame()
        threshold = self.options.threshold
        decoder = self.options.decoder
        binarization_type = self.options.binarization_type
        post_processing = None  # NOTE: overrides the argument; post processing is disabled here

        ptr = self.options.pooling_time_ratio

        # Frame level measure
        frame_measure = [
            ConfusionMatrix() for i in range(len(self.options.classes))
        ]
        tag_measure = ConfusionMatrix()

        self.model.eval()

        for (data, target, data_ids) in self.valid_loader:
            data, target = data.to(self.device), target.to(self.device)
            predicts = self.model(data)

            # compute classification loss
            strong_class_loss = self.classification_criterion(
                predicts["strong"], target)
            weak_class_loss = self.classification_criterion(
                predicts["weak"],
                target.max(dim=1)[0])
            self.strong_losses.update(strong_class_loss.item())
            self.weak_losses.update(weak_class_loss.item())

            predicts["strong"] = torch.sigmoid(
                predicts["strong"]).cpu().data.numpy()
            predicts["weak"] = torch.sigmoid(
                predicts["weak"]).cpu().data.numpy()

            if binarization_type == "class_threshold":
                for i in range(predicts["strong"].shape[0]):
                    predicts["strong"][i] = ProbabilityEncoder().binarization(
                        predicts["strong"][i],
                        binarization_type=binarization_type,
                        threshold=threshold,
                        time_axis=0,
                    )
            else:
                predicts["strong"] = ProbabilityEncoder().binarization(
                    predicts["strong"],
                    binarization_type=binarization_type,
                    threshold=threshold,
                )
                predicts["weak"] = ProbabilityEncoder().binarization(
                    predicts["weak"],
                    binarization_type=binarization_type,
                    threshold=threshold,
                )

            # For debug, frame level measure
            for i in range(len(predicts["strong"])):
                target_np = target.cpu().numpy()
                tn, fp, fn, tp = confusion_matrix(target_np[i].max(axis=0),
                                                  predicts["weak"][i],
                                                  labels=[0, 1]).ravel()
                tag_measure.add_cf(tn, fp, fn, tp)
                for j in range(len(self.options.classes)):
                    tn, fp, fn, tp = confusion_matrix(target_np[i][:, j],
                                                      predicts["strong"][i][:,
                                                                            j],
                                                      labels=[0, 1]).ravel()
                    frame_measure[j].add_cf(tn, fp, fn, tp)

            if post_processing is not None:
                for i in range(predicts["strong"].shape[0]):
                    for post_process_fn in post_processing:
                        predicts["strong"][i] = post_process_fn(
                            predicts["strong"][i])

            for pred, data_id in zip(predicts["strong"], data_ids):
                pred = decoder(pred)
                pred = pd.DataFrame(pred,
                                    columns=["event_label", "onset", "offset"])

                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= ptr / (
                    self.options.sample_rate / self.options.hop_size)
                pred.loc[:,
                         ["onset", "offset"]] = pred[["onset", "offset"]].clip(
                             0, self.options.max_len_seconds)

                pred["filename"] = data_id
                prediction_df = pd.concat([prediction_df, pred])

        # After the validation loop: save predictions
        prediction_df.to_csv(
            self.exp_name / "predictions" /
            f"{self.forward_count}th_iterations.csv",
            index=False,
            sep="\t",
            float_format="%.3f",
        )

        # Compute evaluation metrics
        events_metric, segments_metric, psds_m_f1 = compute_metrics(
            prediction_df,
            self.options.validation_df,
            self.options.durations_validation,
        )
        macro_f1_event = events_metric.results_class_wise_average_metrics()[
            "f_measure"]["f_measure"]
        macro_f1_segment = segments_metric.results_class_wise_average_metrics()[
            "f_measure"]["f_measure"]

        # Compute frame level macro f1 score
        ave_precision = 0
        ave_recall = 0
        macro_f1 = 0
        for i in range(len(self.options.classes)):
            ave_precision_, ave_recall_, macro_f1_ = frame_measure[i].calc_f1()
            ave_precision += ave_precision_
            ave_recall += ave_recall_
            macro_f1 += macro_f1_
        ave_precision /= len(self.options.classes)
        ave_recall /= len(self.options.classes)
        macro_f1 /= len(self.options.classes)
        weak_f1 = tag_measure.calc_f1()[2]

        metrics = {
            "valid_strong_loss": self.strong_losses.avg,
            "valid_weak_loss": self.weak_losses.avg,
            "event_m_f1": macro_f1_event,
            "segment_m_f1": macro_f1_segment,
            "psds_m_f1": psds_m_f1,
            "frame_level_precision": ave_precision,
            "frame_level_recall": ave_recall,
            "frame_level_macro_f1": macro_f1,
            "weak_f1": weak_f1,
        }

        wandb.log(metrics, step=self.forward_count)

        return metrics
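ConfusionMatrix in this example only needs add_cf and calc_f1. A minimal sketch of such a helper, consistent with its use above; this is an assumption, not the project's actual class:

class ConfusionMatrix:
    """Accumulates binary confusion counts and reports precision/recall/F1."""

    def __init__(self):
        self.tn = self.fp = self.fn = self.tp = 0

    def add_cf(self, tn, fp, fn, tp):
        self.tn += tn
        self.fp += fp
        self.fn += fn
        self.tp += tp

    def calc_f1(self):
        precision = self.tp / (self.tp + self.fp) if self.tp + self.fp else 0.0
        recall = self.tp / (self.tp + self.fn) if self.tp + self.fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        return precision, recall, f1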