def _load_model(state, model_name="PS"):
    ps_args = state[model_name]["args"]
    ps_kwargs = state[model_name]["kwargs"]
    ps = PS(*ps_args, **ps_kwargs)
    ps.load_state_dict(state[model_name]["state_dict"])
    ps.eval()
    ps = to_cuda_if_available(ps)

    ms_args = state["MS"]["args"]
    ms_kwargs = state["MS"]["kwargs"]
    ms = MS(*ms_args, **ms_kwargs)  # rebuild MS with its stored constructor arguments
    ms.load_state_dict(state["MS"]["state_dict"])
    ms.eval()
    ms = to_cuda_if_available(ms)

    ema_args = state["model_ema"]["args"]
    ema_kwargs = state["model_ema"]["kwargs"]
    ema = MS(*ema_args, **ema_kwargs)  # rebuild the EMA model with its stored constructor arguments
    ema.load_state_dict(state["model_ema"]["state_dict"])
    ema.eval()
    ema = to_cuda_if_available(ema)

    logger.info("Model loaded at epoch: {}".format(state["epoch"]))
    logger.info(ps)
    return ps, ms, ema
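A minimal usage sketch for _load_model, assuming a checkpoint saved with the key layout read above ("PS", "MS", "model_ema", "epoch"); the file path here is hypothetical:

import torch

# Hypothetical checkpoint path; the dict layout matches what _load_model reads above.
state = torch.load("stored_data/model/best_model.p", map_location="cpu")
ps_model, ms_model, ema_model = _load_model(state, model_name="PS")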
    def forward(self, x):
        # input size : (batch_size, n_channels, n_frames, n_freq)
        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if freq != 1:
            warnings.warn(
                f"Output shape is {(bs, frames, chan * freq)}: {freq} frequency bins remain after the CNN"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        weak = torch.FloatTensor([])
        weak = to_cuda_if_available(weak)
        if self.attention:
            for i in range(self.nclass):
                # attention weight
                sof = torch.sum(x * self.weights[i],
                                dim=-1) + self.bias[i]  # [bs, frames]
                sof = sof / 256  # scale the attention logits by the (hard-coded) feature dimension
                sof = self.softmax(sof)

                # contextual representation
                cr = torch.matmul(sof.unsqueeze(1), x).squeeze(1)

                # audio tagging
                at = self.clf[i](cr)
                at = nn.Sigmoid()(at)
                weak = torch.cat((weak, at), dim=-1)

        return weak
    def forward(self, x):
        # input size : (batch_size, n_channels, n_frames, n_freq)
        # conv features
        x = self.cnn(x)

        bs, chan, frames, freq = x.size()
        if freq != 1:
            warnings.warn(
                f"Output shape is {(bs, frames, chan * freq)}: {freq} frequency bins remain after the CNN"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]
        cnn_x = x

        weak = torch.FloatTensor([])
        weak = to_cuda_if_available(weak)
        strong = torch.FloatTensor([])
        strong = to_cuda_if_available(strong)

        # ATP
        if self.attention:
            for i in range(self.nclass):
                x_c = x[:, :, :self.df[i]]
                # attention weight
                sof = torch.sum(x_c * self.weights[i],
                                dim=-1) + self.bias[i]  # [bs, frames]
                sed = self.sigmoid(sof).unsqueeze(-1)
                strong = torch.cat((strong, sed), dim=-1)
                sof = sof / self.df[i]  # scale the attention logits by this class's feature dimension
                sof = self.softmax(sof)

                # contextual representation
                cr = torch.matmul(sof.unsqueeze(1), x_c).squeeze(1)

                # audio tagging
                at = self.clf[i](cr)
                at = nn.Sigmoid()(at)
                weak = torch.cat((weak, at), dim=-1)

            phi = generate_label(weak).unsqueeze(1)
            strong = strong * phi

        return strong, weak, cnn_x
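A brief shape walk-through of the attention pooling (ATP) loop above, assuming self.weights[i] has shape [df[i]], self.softmax normalizes over the frame dimension, and self.clf[i] maps the df[i]-dimensional representation to a single output:

# x_c                 : [bs, frames, df[i]]  per-class slice of the CNN features
# sof (after sum)     : [bs, frames]         attention logits over time
# sed                 : [bs, frames, 1]      sigmoid of the logits (frame-level prediction)
# sof (after softmax) : [bs, frames]         attention weights over time
# cr                  : [bs, df[i]]          attention-weighted average over frames
# at                  : [bs, 1]              clip-level prediction for class i
# strong, weak        : [bs, frames, nclass] and [bs, nclass] after the loop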
Example 4
def _load_crnn(state, model_name="model"):
    crnn_args = state[model_name]["args"]
    crnn_kwargs = state[model_name]["kwargs"]
    crnn = CRNN(*crnn_args, **crnn_kwargs)
    crnn.load_state_dict(state[model_name]["state_dict"])
    crnn.eval()
    crnn = to_cuda_if_available(crnn)
    logger.info("Model loaded at epoch: {}".format(state["epoch"]))
    logger.info(crnn)
    return crnn
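For reference, a sketch of the checkpoint structure _load_crnn expects; the keys mirror what the function reads above, while crnn, crnn_args, crnn_kwargs, epoch and the path are hypothetical placeholders:

import torch

state = {
    "model": {
        "args": crnn_args,        # positional args used to rebuild CRNN(...)
        "kwargs": crnn_kwargs,    # keyword args used to rebuild CRNN(...)
        "state_dict": crnn.state_dict(),
    },
    "epoch": epoch,
}
torch.save(state, "stored_data/model/crnn_checkpoint.p")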
Example 5
def generate_label(pred):
    p = []
    for b in range(pred.size(0)):
        pred_bin = ProbabilityEncoder().binarization(pred[b].detach().cpu().numpy(),
                                                     binarization_type="global_threshold",
                                                     threshold=0.5)
        p.append(pred_bin)

    p = torch.FloatTensor(p)
    p = to_cuda_if_available(p)
    return p
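A quick illustrative check of generate_label on a toy batch of weak probabilities, assuming ProbabilityEncoder's global_threshold binarization maps values above 0.5 to 1 and the rest to 0:

import torch

weak_probs = torch.tensor([[0.2, 0.7, 0.9],
                           [0.8, 0.1, 0.4]])
print(generate_label(weak_probs))
# expected: tensor([[0., 1., 1.],
#                   [1., 0., 0.]])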
def get_predictions(model,
                    valid_dataloader,
                    decoder,
                    pooling_time_ratio=1,
                    median_window=1,
                    save_predictions=None):
    prediction_df = pd.DataFrame()
    for i, ((input_data, _), indexes) in enumerate(valid_dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)

        pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        for j, pred_strong_it in enumerate(pred_strong):
            pred_strong_it = ProbabilityEncoder().binarization(
                pred_strong_it,
                binarization_type="global_threshold",
                threshold=0.5)
            pred_strong_it = scipy.ndimage.filters.median_filter(
                pred_strong_it, (median_window, 1))
            pred = decoder(pred_strong_it)
            pred = pd.DataFrame(pred,
                                columns=["event_label", "onset", "offset"])
            pred["filename"] = valid_dataloader.dataset.filenames.iloc[
                indexes[j]]
            prediction_df = prediction_df.append(pred)

            if i == 0 and j == 0:
                logger.debug("predictions: \n{}".format(pred))
                logger.debug("predictions strong: \n{}".format(pred_strong_it))

    # In seconds
    prediction_df.loc[:,
                      "onset"] = prediction_df.onset * pooling_time_ratio / (
                          cfg.sample_rate / cfg.hop_size)
    prediction_df.loc[:,
                      "offset"] = prediction_df.offset * pooling_time_ratio / (
                          cfg.sample_rate / cfg.hop_size)
    prediction_df = prediction_df.reset_index(drop=True)
    if save_predictions is not None:
        dir_to_create = osp.dirname(save_predictions)
        if dir_to_create != "":
            os.makedirs(dir_to_create, exist_ok=True)
        logger.info("Saving predictions at: {}".format(save_predictions))
        prediction_df.to_csv(save_predictions,
                             index=False,
                             sep="\t",
                             float_format="%.3f")
    return prediction_df
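For clarity, a small worked example of the frame-to-seconds conversion applied above; the sample rate, hop size and pooling ratio are illustrative assumptions, not values taken from this config:

sample_rate = 16000        # assumed for illustration
hop_size = 255             # assumed for illustration
pooling_time_ratio = 8     # assumed CNN temporal pooling factor
frames_per_second = sample_rate / hop_size            # ~62.7 feature frames per second
onset_frame = 100                                      # model output frame index
onset_seconds = onset_frame * pooling_time_ratio / frames_per_second
print(round(onset_seconds, 3))                         # -> 12.75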
def generate_label(pred):
    p = []
    thresh = [0.45, 0.5, 0.5, 0.5, 0.5, 0.5, 0.45, 0.45, 0.5, 0.35]
    for b in range(pred.size(0)):
        pred_bin = np.array([])
        for c in range(10):
            pred_b = ProbabilityEncoder().binarization(
                pred[b][c].unsqueeze(0).cpu().detach().numpy(),
                binarization_type="global_threshold",
                threshold=thresh[c])
            pred_bin = np.concatenate((pred_bin, pred_b), axis=0)
        p.append(pred_bin)

    p = torch.FloatTensor(p)
    p = to_cuda_if_available(p)
    return p
Example 8
def train(train_loader,
          model,
          optimizer,
          c_epoch,
          ema_model=None,
          mask_weak=None,
          mask_strong=None,
          adjust_lr=False):
    """ One epoch of a Mean Teacher model
    Args:
        train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
            Should return a tuple: ((student input, teacher input), labels)
        model: torch.Module, student model to be trained, should return a weak and strong prediction
        optimizer: torch.optim.Optimizer, optimizer used to train the model
        c_epoch: int, the current epoch of training
        ema_model: torch.Module, teacher model (exponential moving average of the student),
            should return a weak and strong prediction
        mask_weak: slice or list, mask the batch to get only the weakly labeled data (used to calculate the loss)
        mask_strong: slice or list, mask the batch to get only the strongly labeled data (used to calculate the loss)
        adjust_lr: bool, whether or not to adjust the learning rate during training (params in config)
    """
    log = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name,
                        terminal_level=cfg.terminal_level)
    class_criterion = nn.BCELoss()
    consistency_criterion = nn.MSELoss()
    class_criterion, consistency_criterion = to_cuda_if_available(
        class_criterion, consistency_criterion)

    meters = AverageMeterSet()
    log.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    for i, ((batch_input, ema_batch_input), target) in enumerate(train_loader):
        global_step = c_epoch * len(train_loader) + i
        rampup_value = ramps.exp_rampup(global_step,
                                        cfg.n_epoch_rampup * len(train_loader))

        if adjust_lr:
            adjust_learning_rate(optimizer, rampup_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])
        batch_input, ema_batch_input, target = to_cuda_if_available(
            batch_input, ema_batch_input, target)
        # Outputs
        strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
        strong_pred_ema = strong_pred_ema.detach()
        weak_pred_ema = weak_pred_ema.detach()
        strong_pred, weak_pred = model(batch_input)

        loss = None
        # Weak BCE Loss
        target_weak = target.max(-2)[0]  # Take the max in the time axis
        if mask_weak is not None:
            weak_class_loss = class_criterion(weak_pred[mask_weak],
                                              target_weak[mask_weak])
            ema_class_loss = class_criterion(weak_pred_ema[mask_weak],
                                             target_weak[mask_weak])
            loss = weak_class_loss

            if i == 0:
                log.debug(
                    f"target: {target.mean(-2)} \n Target_weak: {target_weak} \n "
                    f"Target weak mask: {target_weak[mask_weak]} \n "
                    f"Target strong mask: {target[mask_strong].sum(-2)}\n"
                    f"weak loss: {weak_class_loss} \t rampup_value: {rampup_value}"
                    f"tensor mean: {batch_input.mean()}")
            meters.update('weak_class_loss', weak_class_loss.item())
            meters.update('Weak EMA loss', ema_class_loss.item())

        # Strong BCE loss
        if mask_strong is not None:
            strong_class_loss = class_criterion(strong_pred[mask_strong],
                                                target[mask_strong])
            meters.update('Strong loss', strong_class_loss.item())

            strong_ema_class_loss = class_criterion(
                strong_pred_ema[mask_strong], target[mask_strong])
            meters.update('Strong EMA loss', strong_ema_class_loss.item())

            if loss is not None:
                loss += strong_class_loss
            else:
                loss = strong_class_loss

        # Teacher-student consistency cost
        if ema_model is not None:
            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take consistency about strong predictions (all data)
            consistency_loss_strong = consistency_cost * consistency_criterion(
                strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong
            # Take consistency about weak predictions (all data)
            consistency_loss_weak = consistency_cost * consistency_criterion(
                weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        assert not (np.isnan(loss.item())
                    or loss.item() > 1e5), 'Loss explosion: {}'.format(
                        loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start
    log.info(f"Epoch: {c_epoch}\t Time {epoch_time:.2f}\t {meters}")
    return loss, meters
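update_ema_variables is defined elsewhere in the repo; for reference, a minimal sketch of the standard Mean Teacher EMA update it is expected to perform (the usual formulation, not necessarily this repo's exact code):

def update_ema_variables_sketch(model, ema_model, alpha, global_step):
    # Ramp the decay up from 0 so the teacher tracks the student closely early in training.
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # teacher = alpha * teacher + (1 - alpha) * student
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)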
Example 9
    save_best_cb = SaveBest("sup")
    if cfg.early_stopping is not None:
        early_stopping_call = EarlyStopping(patience=cfg.early_stopping,
                                            val_comp="sup",
                                            init_patience=cfg.es_init_wait)

    # ##############
    # Train
    # ##############
    results = pd.DataFrame(
        columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"])
    for epoch in range(cfg.n_epoch):
        crnn.train()
        crnn_ema.train()
        crnn, crnn_ema = to_cuda_if_available(crnn, crnn_ema)

        loss_value, meters = train(training_loader,
                                   crnn,
                                   optim,
                                   epoch,
                                   ema_model=crnn_ema,
                                   mask_weak=weak_mask,
                                   mask_strong=strong_mask,
                                   adjust_lr=cfg.adjust_lr)

        # Validation
        crnn = crnn.eval()
        logger.info("\n ### Valid synthetic metric ### \n")
        predictions = get_predictions(crnn,
                                      valid_synth_loader,
Example 10
    if cfg.early_stopping is not None:
        early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup", init_patience=cfg.es_init_wait)

    # in case of not using ema_model
    wcrnn_ema = None


    # ##############
    # Train
    # ##############
    results = pd.DataFrame(columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"])
    for epoch in range(cfg.n_epoch):
        wcrnn.train()
        #wcrnn_ema.train()
        #wcrnn, wcrnn_ema = to_cuda_if_available(wcrnn, wcrnn_ema)
        wcrnn = to_cuda_if_available(wcrnn)

        loss_value = train(training_loader, wcrnn, optim, epoch,
                           ema_model=None, mask_weak=weak_mask, mask_strong=strong_mask, adjust_lr=cfg.adjust_lr)
                           #ema_model=wcrnn_ema, mask_weak=weak_mask, mask_strong=strong_mask, adjust_lr=cfg.adjust_lr)
        # Validation
        wcrnn = wcrnn.eval()
        logger.info("\n ### Valid synthetic metric ### \n")
        predictions = get_predictions(wcrnn, valid_synth_loader, many_hot_encoder.decode_strong, pooling_time_ratio,
                                      median_window=median_window, save_predictions=None)
        # Validation with synthetic data (dropping feature_filename for psds)
        valid_synth = dfs["valid_synthetic"].drop("feature_filename", axis=1)
        valid_synth_f1, psds_m_f1 = compute_metrics(predictions, valid_synth, durations_synth)

        # Update state
        state['model']['state_dict'] = wcrnn.state_dict()
def get_predictions(model, dataloader, decoder, pooling_time_ratio=1, thresholds=[0.5],
                    median_window=1, save_predictions=None):
    """ Get the predictions of a trained model on a specific set
    Args:
        model: torch.Module, a trained pytorch model (you usually want it to be in .eval() mode).
        dataloader: torch.utils.data.DataLoader, giving ((input_data, label), indexes) but label is not used here
        decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and returns a list of lists
            of ("event_label", "onset", "offset") for each label predicted.
        pooling_time_ratio: int, the ratio between input timesteps and output timesteps (temporal pooling factor)
        median_window: int, the median window (in number of time steps) to be applied
        save_predictions: str or list, the path of the base_filename to save the predictions or a list of names
            corresponding to each threshold
        thresholds: list, list of thresholds to be applied

    Returns:
        dict of the different predictions with associated threshold
    """

    # Init a dataframe per threshold
    prediction_dfs = {}
    for threshold in thresholds:
        # print('threshold', threshold)
        prediction_dfs[threshold] = pd.DataFrame()

    # Transitions
    transitions = get_transitions([0.8936181121106661,
                                   0.31550391085547114, 0.8923570974604905,
                                   0.6490908582453855, 0.949481580049081,
                                   0.7288818269203898, 0.4234812116634173,
                                   0.5027625909657499, 0.8155335690232651,
                                   0.6838988940188258, 0.6928248693985566])

    # Get predictions
    for i, ((input_data, _), indexes) in enumerate(dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)
        with torch.no_grad():
            pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        # Post processing and put predictions in a dataframe
        for j, pred_strong_it in enumerate(pred_strong):
            for threshold in thresholds:
                # pred_strong_bin = ProbabilityEncoder().binarization(pred_strong_it,
                #                                                     binarization_type="global_threshold",
                #                                                     threshold=threshold)

                pred_strong_bin = apply_HMMs(pred_strong_it, transitions)
                # print("PRED_STRONG:", pred_strong_bin)
                # pred_strong_m = scipy.ndimage.filters.median_filter(pred_strong_bin, (median_window, 1))

                pred_strong_m = pred_strong_bin

                pred = decoder(pred_strong_m)
                pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= pooling_time_ratio / (cfg.sample_rate / cfg.hop_size)
                pred.loc[:, ["onset", "offset"]] = pred[["onset", "offset"]].clip(0, cfg.max_len_seconds)

                pred["filename"] = dataloader.dataset.filenames.iloc[indexes[j]]
                prediction_dfs[threshold] = prediction_dfs[threshold].append(pred, ignore_index=True)

                if i == 0 and j == 0:
                    logger.debug("predictions: \n{}".format(pred))
                    logger.debug("predictions strong: \n{}".format(pred_strong_it))

    # Save predictions
    if save_predictions is not None:
        if isinstance(save_predictions, str):
            if len(thresholds) == 1:
                save_predictions = [save_predictions]
            else:
                base, ext = osp.splitext(save_predictions)
                save_predictions = [osp.join(base, f"{threshold:.3f}{ext}") for threshold in thresholds]
        else:
            assert len(save_predictions) == len(thresholds), \
                f"There should be a prediction file per threshold: len predictions: {len(save_predictions)}\n" \
                f"len thresholds: {len(thresholds)}"

        for ind, threshold in enumerate(thresholds):
            dir_to_create = osp.dirname(save_predictions[ind])
            if dir_to_create != "":
                os.makedirs(dir_to_create, exist_ok=True)
            if ind % 10 == 0:
                logger.info(f"Saving predictions at: {save_predictions[ind]}. {ind + 1} / {len(thresholds)}")
            prediction_dfs[threshold].to_csv(save_predictions[ind], index=False, sep="\t", float_format="%.3f")

    list_predictions = []
    for key in prediction_dfs:
        list_predictions.append(prediction_dfs[key])

    if len(list_predictions) == 1:
        list_predictions = list_predictions[0]

    return list_predictions
Example 12
    save_best_cb = SaveBest("sup")
    if cfg.early_stopping is not None:
        early_stopping_call = EarlyStopping(patience=cfg.early_stopping,
                                            val_comp="sup",
                                            init_patience=cfg.es_init_wait)

    # ##############
    # Train
    # ##############
    results = pd.DataFrame(
        columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"])
    for epoch in range(cfg.n_epoch):
        wcrnn.train()
        wcrnn_ema.train()
        wcrnn, wcrnn_ema = to_cuda_if_available(wcrnn, wcrnn_ema)

        loss_value = train(training_loader,
                           wcrnn,
                           optim,
                           epoch,
                           ema_model=wcrnn_ema,
                           mask_weak=weak_mask,
                           mask_strong=strong_mask,
                           adjust_lr=cfg.adjust_lr)
        # Validation
        wcrnn = wcrnn.eval()
        logger.info("\n ### Valid synthetic metric ### \n")
        predictions = get_predictions(wcrnn,
                                      valid_synth_loader,
                                      many_hot_encoder.decode_strong,
def get_predictions_ss_late_integration(model, valid_dataload, decoder, pooling_time_ratio=1, thresholds=[0.5],
                                        median_window=1, save_predictions=None, alpha=1):
    """ Get the predictions of a trained model on a specific set
    Args:
        model: torch.Module, a trained pytorch model (you usually want it to be in .eval() mode).
        valid_dataload: DataLoadDf, giving ((input_data, label), index); label is not used here. The multiple
            inputs are the mixture followed by the separated sources (the mixture must always come first).
            Example: if the input data has shape (3, 1, timesteps, freq), it contains the mixture and 2 sources.
        decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and returns a list of lists
            of ("event_label", "onset", "offset") for each label predicted.
        pooling_time_ratio: int, the ratio between input timesteps and output timesteps (temporal pooling factor)
        median_window: int, the median window (in number of time steps) to be applied
        save_predictions: str or list, the path of the base_filename to save the predictions or a list of names
            corresponding to each threshold
        thresholds: list, list of thresholds to be applied
        alpha: float, the value of the norm used to combine the mixture and source predictions

    Returns:
        dict of the different predictions with associated threshold
    """

    # Init a dataframe per threshold
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()

    # Get predictions
    for i, ((input_data, _), index) in enumerate(valid_dataload):
        input_data = to_cuda_if_available(input_data)
        with torch.no_grad():
            pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        pred_strong_sources = pred_strong[1:]
        pred_strong_sources = norm_alpha(pred_strong_sources, alpha)
        pred_strong_comb = norm_alpha(np.stack((pred_strong[0], pred_strong_sources), 0), alpha)

        # Get different post processing per threshold
        for threshold in thresholds:
            pred_strong_bin = ProbabilityEncoder().binarization(pred_strong_comb,
                                                                binarization_type="global_threshold",
                                                                threshold=threshold)
            pred_strong_m = scipy.ndimage.filters.median_filter(pred_strong_bin, (median_window, 1))
            pred = decoder(pred_strong_m)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            # Put them in seconds
            pred.loc[:, ["onset", "offset"]] *= pooling_time_ratio / (cfg.sample_rate / cfg.hop_size)
            pred.loc[:, ["onset", "offset"]] = pred[["onset", "offset"]].clip(0, cfg.max_len_seconds)

            pred["filename"] = valid_dataload.filenames.iloc[index]
            prediction_dfs[threshold] = prediction_dfs[threshold].append(pred, ignore_index=True)

            if i == 0:
                logger.debug("predictions: \n{}".format(pred))
                logger.debug("predictions strong: \n{}".format(pred_strong_comb))

    # Save predictions
    if save_predictions is not None:
        if isinstance(save_predictions, str):
            if len(thresholds) == 1:
                save_predictions = [save_predictions]
            else:
                base, ext = osp.splitext(save_predictions)
                save_predictions = [osp.join(base, f"{threshold:.3f}{ext}") for threshold in thresholds]
        else:
            assert len(save_predictions) == len(thresholds), \
                f"There should be a prediction file per threshold: len predictions: {len(save_predictions)}\n" \
                f"len thresholds: {len(thresholds)}"

        for ind, threshold in enumerate(thresholds):
            dir_to_create = osp.dirname(save_predictions[ind])
            if dir_to_create != "":
                os.makedirs(dir_to_create, exist_ok=True)
            if ind % 10 == 0:
                logger.info(f"Saving predictions at: {save_predictions[ind]}. {ind + 1} / {len(thresholds)}")
            prediction_dfs[threshold].to_csv(save_predictions[ind], index=False, sep="\t", float_format="%.3f")

    list_predictions = []
    for key in prediction_dfs:
        list_predictions.append(prediction_dfs[key])

    if len(list_predictions) == 1:
        list_predictions = list_predictions[0]

    return list_predictions
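norm_alpha is also defined elsewhere in the repo; a plausible sketch of an alpha-norm (generalized mean) combination along the first axis, under the assumption that this is what the helper computes:

import numpy as np

def norm_alpha_sketch(predictions, alpha):
    # Generalized mean over the first axis: (mean(x ** alpha)) ** (1 / alpha).
    # alpha = 1 gives the arithmetic mean; larger alpha moves towards the maximum.
    return np.mean(predictions ** alpha, axis=0) ** (1.0 / alpha)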
def get_predictions(encoder,
                    ema,
                    ms,
                    dataloader,
                    decoder,
                    pooling_time_ratio=1,
                    thresholds=[0.5],
                    median_window=1,
                    save_predictions=None):
    """ Get the predictions of a trained model on a specific set
    Args:
        encoder: torch.Module, trained PS model returning (strong, weak, cnn_features), usually in .eval() mode.
        ema: torch.Module, teacher (EMA) model; kept in the signature but not used in this function.
        ms: torch.Module, trained MS model taking the CNN features and returning (strong, weak) predictions.
        dataloader: torch.utils.data.DataLoader, giving ((input_data, label), indexes) but label is not used here
        decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and returns a list of lists
            of ("event_label", "onset", "offset") for each label predicted.
        pooling_time_ratio: int, the ratio between input timesteps and output timesteps (temporal pooling factor)
        median_window: int, the median window (in number of time steps) to be applied
        save_predictions: str or list, the path of the base_filename to save the predictions or a list of names
            corresponding to each threshold
        thresholds: list, list of thresholds to be applied

    Returns:
        dict of the different predictions with associated threshold
    """

    # Init a dataframe per threshold
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()

    # Get predictions
    for i, ((input_data, _), indexes) in enumerate(dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)
        with torch.no_grad():
            enc_strong, enc_weak, feature = encoder(input_data)
            pred_strong, pred_weak = ms(feature)

        # Binarize the average of the two models' weak predictions and use it to mask the strong predictions
        phi = generate_label(0.5 * (enc_weak + pred_weak)).unsqueeze(1)
        pred_strong = pred_strong * phi
        enc_strong = enc_strong.cpu()
        enc_strong = enc_strong.detach().numpy()
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        # Post processing and put predictions in a dataframe
        for j, pred_strong_it in enumerate(pred_strong):
            for threshold in thresholds:
                # enc_strong_it = adaptive_median_filter(enc_strong[j], decoder)
                pred_strong_it = adaptive_median_filter(
                    pred_strong_it, decoder)
                pred_strong_bin = ProbabilityEncoder().binarization(
                    pred_strong_it,
                    binarization_type="global_threshold",
                    threshold=threshold)
                # enc_strong_bin = ProbabilityEncoder().binarization(enc_strong_it,
                #                                                     binarization_type="global_threshold",
                #                                                     threshold=threshold)

                pred_strong_m = adaptive_median_filter(pred_strong_bin,
                                                       decoder)
                # enc_strong_m = adaptive_median_filter(enc_strong_bin, decoder)
                pred = decoder(pred_strong_m)
                pred = pd.DataFrame(pred,
                                    columns=["event_label", "onset", "offset"])
                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= pooling_time_ratio / (
                    cfg.sample_rate / cfg.hop_size)
                pred.loc[:,
                         ["onset", "offset"]] = pred[["onset", "offset"]].clip(
                             0, cfg.max_len_seconds)

                pred["filename"] = dataloader.dataset.filenames.iloc[
                    indexes[j]]
                prediction_dfs[threshold] = prediction_dfs[threshold].append(
                    pred, ignore_index=True)

                if i == 0 and j == 0:
                    logger.debug("predictions: \n{}".format(pred))
                    logger.debug(
                        "predictions strong: \n{}".format(pred_strong_it))

    # Save predictions
    if save_predictions is not None:
        if isinstance(save_predictions, str):
            if len(thresholds) == 1:
                save_predictions = [save_predictions]
            else:
                base, ext = osp.splitext(save_predictions)
                save_predictions = [
                    osp.join(base, f"{threshold:.3f}{ext}")
                    for threshold in thresholds
                ]
        else:
            assert len(save_predictions) == len(thresholds), \
                f"There should be a prediction file per threshold: len predictions: {len(save_predictions)}\n" \
                f"len thresholds: {len(thresholds)}"

        for ind, threshold in enumerate(thresholds):
            dir_to_create = osp.dirname(save_predictions[ind])
            if dir_to_create != "":
                os.makedirs(dir_to_create, exist_ok=True)
            if ind % 10 == 0:
                logger.info(
                    f"Saving predictions at: {save_predictions[ind]}. {ind + 1} / {len(thresholds)}"
                )
            prediction_dfs[threshold].to_csv(save_predictions[ind],
                                             index=False,
                                             sep="\t",
                                             float_format="%.3f")

    list_predictions = []
    for key in prediction_dfs:
        list_predictions.append(prediction_dfs[key])

    if len(list_predictions) == 1:
        list_predictions = list_predictions[0]

    return list_predictions
Example 15
    }
    

    save_best_cb = SaveBest("sup")
    if cfg.early_stopping is not None:
        early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup", init_patience=cfg.es_init_wait)

    # ##############
    # Train
    # ##############
    results = pd.DataFrame(columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"])
    for epoch in range(cfg.n_epoch):
        MS_model.train()
        MS_ema.train()
        PS_model.train()
        MS_model, MS_ema, PS_model = to_cuda_if_available(MS_model, MS_ema, PS_model)

        loss_value = train(training_loader, MS_model, PS_model, ms_optim, ps_optim, epoch,
                           ema_model=MS_ema, mask_weak=weak_mask, mask_strong=strong_mask, adjust_lr=cfg.adjust_lr)

        # Validation
        ema = MS_ema.eval()
        MS_m = MS_model.eval()
        PS_m = PS_model.eval()
        logger.info("\n ### Valid synthetic metric ### \n")
        predictions = get_predictions(PS_m, ema, MS_m, valid_synth_loader, many_hot_encoder.decode_strong, pooling_time_ratio,
                                      median_window=median_window, save_predictions=None)
        # Validation with synthetic data (dropping feature_filename for psds)
        valid_synth = dfs["validation"].drop("feature_filename", axis=1)
        valid_synth_f1, psds_m_f1 = compute_metrics(predictions, valid_synth, durations_validation)