def _load_model(state, model_name="PS"):
    ps_args = state[model_name]["args"]
    ps_kwargs = state[model_name]["kwargs"]
    ps = PS(*ps_args, **ps_kwargs)
    ps.load_state_dict(state[model_name]["state_dict"])
    ps.eval()
    ps = to_cuda_if_available(ps)

    ms_args = state["MS"]["args"]
    ms_kwargs = state["MS"]["kwargs"]
    ms = MS(*ms_args, **ms_kwargs)
    ms.load_state_dict(state["MS"]["state_dict"])
    ms.eval()
    ms = to_cuda_if_available(ms)

    ema_args = state["model_ema"]["args"]
    ema_kwargs = state["model_ema"]["kwargs"]
    ema = MS(*ema_args, **ema_kwargs)
    ema.load_state_dict(state["model_ema"]["state_dict"])
    ema.eval()
    ema = to_cuda_if_available(ema)

    logger.info("Model loaded at epoch: {}".format(state["epoch"]))
    logger.info(ps)
    return ps, ms, ema
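# Usage sketch (illustrative; the checkpoint path is hypothetical): restore the
# three models from a checkpoint saved by the training loop further below, which
# stores the "PS", "MS" and "model_ema" entries that _load_model reads.
state = torch.load("stored_data/model/ps_ms_best", map_location="cpu")
ps_model, ms_model, ema_model = _load_model(state, model_name="PS")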
def forward(self, x):
    # input size : (batch_size, n_channels, n_frames, n_freq)
    # conv features
    x = self.cnn(x)
    bs, chan, frames, freq = x.size()
    if freq != 1:
        warnings.warn(f"Output shape is: {(bs, frames, chan * freq)}; frequency axis not reduced (freq={freq})")
        x = x.permute(0, 2, 1, 3)
        x = x.contiguous().view(bs, frames, chan * freq)
    else:
        x = x.squeeze(-1)
        x = x.permute(0, 2, 1)  # [bs, frames, chan]

    weak = torch.FloatTensor([])
    weak = to_cuda_if_available(weak)

    if self.attention:
        for i in range(self.nclass):
            # attention weight
            sof = torch.sum(x * self.weights[i], dim=-1) + self.bias[i]  # [bs, frames]
            sof = sof / 256  # scale by the feature dimension before the softmax
            sof = self.softmax(sof)
            # contextual representation
            cr = torch.matmul(sof.unsqueeze(1), x).squeeze(1)
            # audio tagging
            at = self.clf[i](cr)
            at = nn.Sigmoid()(at)
            weak = torch.cat((weak, at), dim=-1)
    return weak
def forward(self, x):
    # input size : (batch_size, n_channels, n_frames, n_freq)
    # conv features
    x = self.cnn(x)
    bs, chan, frames, freq = x.size()
    if freq != 1:
        warnings.warn(f"Output shape is: {(bs, frames, chan * freq)}; frequency axis not reduced (freq={freq})")
        x = x.permute(0, 2, 1, 3)
        x = x.contiguous().view(bs, frames, chan * freq)
    else:
        x = x.squeeze(-1)
        x = x.permute(0, 2, 1)  # [bs, frames, chan]
    cnn_x = x

    weak = torch.FloatTensor([])
    weak = to_cuda_if_available(weak)
    strong = torch.FloatTensor([])
    strong = to_cuda_if_available(strong)

    # ATP
    if self.attention:
        for i in range(self.nclass):
            x_c = x[:, :, :self.df[i]]  # per-class slice of the feature dimension
            # attention weight
            sof = torch.sum(x_c * self.weights[i], dim=-1) + self.bias[i]  # [bs, frames]
            sed = self.sigmoid(sof).unsqueeze(-1)
            strong = torch.cat((strong, sed), dim=-1)
            sof = sof / self.df[i]  # scale by the per-class feature dimension before the softmax
            sof = self.softmax(sof)
            # contextual representation
            cr = torch.matmul(sof.unsqueeze(1), x_c).squeeze(1)
            # audio tagging
            at = self.clf[i](cr)
            at = nn.Sigmoid()(at)
            weak = torch.cat((weak, at), dim=-1)
    # Gate the frame-level predictions with the binarized clip-level predictions
    phi = generate_label(weak).unsqueeze(1)
    strong = strong * phi
    return strong, weak, cnn_x
def _load_crnn(state, model_name="model"):
    crnn_args = state[model_name]["args"]
    crnn_kwargs = state[model_name]["kwargs"]
    crnn = CRNN(*crnn_args, **crnn_kwargs)
    crnn.load_state_dict(state[model_name]["state_dict"])
    crnn.eval()
    crnn = to_cuda_if_available(crnn)
    logger.info("Model loaded at epoch: {}".format(state["epoch"]))
    logger.info(crnn)
    return crnn
def generate_label(pred):
    p = []
    for b in range(pred.size(0)):
        pred_bin = ProbabilityEncoder().binarization(pred[b].detach().cpu().numpy(),
                                                     binarization_type="global_threshold",
                                                     threshold=0.5)
        p.append(pred_bin)
    p = torch.FloatTensor(p)
    p = to_cuda_if_available(p)
    return p
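# Minimal usage sketch: clip-level probabilities are binarized at 0.5 into a
# {0, 1} mask (moved to GPU if available) that later gates frame-level outputs.
probs = torch.tensor([[0.9, 0.2, 0.6]])
clip_mask = generate_label(probs)  # -> tensor([[1., 0., 1.]])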
def get_predictions(model, valid_dataloader, decoder, pooling_time_ratio=1, median_window=1,
                    save_predictions=None):
    prediction_df = pd.DataFrame()

    for i, ((input_data, _), indexes) in enumerate(valid_dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)

        pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        for j, pred_strong_it in enumerate(pred_strong):
            pred_strong_it = ProbabilityEncoder().binarization(pred_strong_it,
                                                               binarization_type="global_threshold",
                                                               threshold=0.5)
            pred_strong_it = scipy.ndimage.filters.median_filter(pred_strong_it, (median_window, 1))
            pred = decoder(pred_strong_it)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            pred["filename"] = valid_dataloader.dataset.filenames.iloc[indexes[j]]
            prediction_df = prediction_df.append(pred)

            if i == 0 and j == 0:
                logger.debug("predictions: \n{}".format(pred))
                logger.debug("predictions strong: \n{}".format(pred_strong_it))

    # In seconds
    prediction_df.loc[:, "onset"] = prediction_df.onset * pooling_time_ratio / (cfg.sample_rate / cfg.hop_size)
    prediction_df.loc[:, "offset"] = prediction_df.offset * pooling_time_ratio / (cfg.sample_rate / cfg.hop_size)
    prediction_df = prediction_df.reset_index(drop=True)

    if save_predictions is not None:
        dir_to_create = osp.dirname(save_predictions)
        if dir_to_create != "":
            os.makedirs(dir_to_create, exist_ok=True)
        logger.info("Saving predictions at: {}".format(save_predictions))
        prediction_df.to_csv(save_predictions, index=False, sep="\t", float_format="%.3f")
    return prediction_df
def generate_label(pred):
    p = []
    # Per-class decision thresholds (tuned per event class)
    thresh = [0.45, 0.5, 0.5, 0.5, 0.5, 0.5, 0.45, 0.45, 0.5, 0.35]
    for b in range(pred.size(0)):
        pred_bin = np.array([])
        for c in range(len(thresh)):
            pred_b = ProbabilityEncoder().binarization(pred[b][c].unsqueeze(0).cpu().detach().numpy(),
                                                       binarization_type="global_threshold",
                                                       threshold=thresh[c])
            pred_bin = np.concatenate((pred_bin, pred_b), axis=0)
        p.append(pred_bin)
    p = torch.FloatTensor(p)
    p = to_cuda_if_available(p)
    return p
def train(train_loader, model, optimizer, c_epoch, ema_model=None, mask_weak=None, mask_strong=None,
          adjust_lr=False):
    """ One epoch of a Mean Teacher model
    Args:
        train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
            Should return a tuple: ((teacher input, student input), labels)
        model: torch.nn.Module, model to be trained, should return a weak and strong prediction
        optimizer: torch.optim.Optimizer, optimizer used to train the model
        c_epoch: int, the current epoch of training
        ema_model: torch.nn.Module, teacher (EMA) model, should return a weak and strong prediction
        mask_weak: slice or list, mask the batch to get only the weakly labeled data (used to calculate the loss)
        mask_strong: slice or list, mask the batch to get only the strongly labeled data (used to calculate the loss)
        adjust_lr: bool, whether or not to adjust the learning rate during training (params in config)
    """
    log = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level)
    class_criterion = nn.BCELoss()
    consistency_criterion = nn.MSELoss()
    class_criterion, consistency_criterion = to_cuda_if_available(class_criterion, consistency_criterion)

    meters = AverageMeterSet()
    log.debug("Nb batches: {}".format(len(train_loader)))
    start = time.time()
    for i, ((batch_input, ema_batch_input), target) in enumerate(train_loader):
        global_step = c_epoch * len(train_loader) + i
        rampup_value = ramps.exp_rampup(global_step, cfg.n_epoch_rampup * len(train_loader))

        if adjust_lr:
            adjust_learning_rate(optimizer, rampup_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        batch_input, ema_batch_input, target = to_cuda_if_available(batch_input, ema_batch_input, target)

        # Outputs (the teacher forward is skipped when no ema_model is given)
        strong_pred, weak_pred = model(batch_input)
        if ema_model is not None:
            strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
            strong_pred_ema = strong_pred_ema.detach()
            weak_pred_ema = weak_pred_ema.detach()

        loss = None
        # Weak BCE Loss
        target_weak = target.max(-2)[0]  # Take the max in the time axis
        if mask_weak is not None:
            weak_class_loss = class_criterion(weak_pred[mask_weak], target_weak[mask_weak])
            loss = weak_class_loss

            if i == 0:
                log.debug(f"target: {target.mean(-2)} \n Target_weak: {target_weak} \n "
                          f"Target weak mask: {target_weak[mask_weak]} \n "
                          f"Target strong mask: {target[mask_strong].sum(-2)}\n"
                          f"weak loss: {weak_class_loss} \t rampup_value: {rampup_value} \t "
                          f"tensor mean: {batch_input.mean()}")
            meters.update('weak_class_loss', weak_class_loss.item())
            if ema_model is not None:
                ema_class_loss = class_criterion(weak_pred_ema[mask_weak], target_weak[mask_weak])
                meters.update('Weak EMA loss', ema_class_loss.item())

        # Strong BCE loss
        if mask_strong is not None:
            strong_class_loss = class_criterion(strong_pred[mask_strong], target[mask_strong])
            meters.update('Strong loss', strong_class_loss.item())
            if ema_model is not None:
                strong_ema_class_loss = class_criterion(strong_pred_ema[mask_strong], target[mask_strong])
                meters.update('Strong EMA loss', strong_ema_class_loss.item())
            if loss is not None:
                loss += strong_class_loss
            else:
                loss = strong_class_loss

        # Teacher-student consistency cost
        if ema_model is not None:
            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take consistency about strong predictions (all data)
            consistency_loss_strong = consistency_cost * consistency_criterion(strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            if loss is not None:
                loss += consistency_loss_strong
            else:
                loss = consistency_loss_strong

            # Take consistency about weak predictions (all data)
            consistency_loss_weak = consistency_cost * consistency_criterion(weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            if loss is not None:
                loss += consistency_loss_weak
            else:
                loss = consistency_loss_weak

        assert not (np.isnan(loss.item()) or loss.item() > 1e5), 'Loss explosion: {}'.format(loss.item())
        assert not loss.item() < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss.item())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if ema_model is not None:
            update_ema_variables(model, ema_model, 0.999, global_step)

    epoch_time = time.time() - start
    log.info(f"Epoch: {c_epoch}\t Time {epoch_time:.2f}\t {meters}")
    return loss, meters
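# update_ema_variables is not shown in this excerpt. A minimal sketch of the
# standard Mean Teacher update it presumably performs (assumption, not the
# verified original): teacher <- alpha * teacher + (1 - alpha) * student, with
# alpha ramped up from 0 early in training so the teacher tracks the student.
def update_ema_variables(model, ema_model, alpha, global_step):
    alpha = min(1 - 1 / (global_step + 1), alpha)  # ramp alpha up from 0
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)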
save_best_cb = SaveBest("sup") if cfg.early_stopping is not None: early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup", init_patience=cfg.es_init_wait) # ############## # Train # ############## results = pd.DataFrame( columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"]) for epoch in range(cfg.n_epoch): crnn.train() crnn_ema.train() crnn, crnn_ema = to_cuda_if_available(crnn, crnn_ema) loss_value, meters = train(training_loader, crnn, optim, epoch, ema_model=crnn_ema, mask_weak=weak_mask, mask_strong=strong_mask, adjust_lr=cfg.adjust_lr) # Validation crnn = crnn.eval() logger.info("\n ### Valid synthetic metric ### \n") predictions = get_predictions(crnn, valid_synth_loader,
if cfg.early_stopping is not None:
    early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup",
                                        init_patience=cfg.es_init_wait)

# in case of not using ema_model
wcrnn_ema = None

# ##############
# Train
# ##############
results = pd.DataFrame(columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"])
for epoch in range(cfg.n_epoch):
    wcrnn.train()
    # wcrnn_ema.train()
    # wcrnn, wcrnn_ema = to_cuda_if_available(wcrnn, wcrnn_ema)
    wcrnn = to_cuda_if_available(wcrnn)

    loss_value = train(training_loader, wcrnn, optim, epoch,
                       ema_model=None, mask_weak=weak_mask, mask_strong=strong_mask,
                       adjust_lr=cfg.adjust_lr)
    # ema_model=wcrnn_ema, mask_weak=weak_mask, mask_strong=strong_mask, adjust_lr=cfg.adjust_lr)

    # Validation
    wcrnn = wcrnn.eval()
    logger.info("\n ### Valid synthetic metric ### \n")
    predictions = get_predictions(wcrnn, valid_synth_loader, many_hot_encoder.decode_strong,
                                  pooling_time_ratio, median_window=median_window,
                                  save_predictions=None)

    # Validation with synthetic data (dropping feature_filename for psds)
    valid_synth = dfs["valid_synthetic"].drop("feature_filename", axis=1)
    valid_synth_f1, psds_m_f1 = compute_metrics(predictions, valid_synth, durations_synth)

    # Update state
    state['model']['state_dict'] = wcrnn.state_dict()
def get_predictions(model, dataloader, decoder, pooling_time_ratio=1, thresholds=[0.5],
                    median_window=1, save_predictions=None):
    """ Get the predictions of a trained model on a specific set
    Args:
        model: torch.Module, a trained pytorch model (you usually want it to be in .eval() mode).
        dataloader: torch.utils.data.DataLoader, giving ((input_data, label), indexes) but label is not used here
        decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and return a list of lists
            of ("event_label", "onset", "offset") for each label predicted.
        pooling_time_ratio: the division to make between timesteps as input and timesteps as output
        median_window: int, the median window (in number of time steps) to be applied
        save_predictions: str or list, the path of the base_filename to save the predictions or a list of names
            corresponding for each thresholds
        thresholds: list, list of threshold to be applied
    Returns:
        dict of the different predictions with associated threshold
    """
    # Init a dataframe per threshold
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()

    # Per-class HMM self-transition probabilities
    transitions = get_transitions([0.8936181121106661, 0.31550391085547114, 0.8923570974604905,
                                   0.6490908582453855, 0.949481580049081, 0.7288818269203898,
                                   0.4234812116634173, 0.5027625909657499, 0.8155335690232651,
                                   0.6838988940188258, 0.6928248693985566])

    # Get predictions
    for i, ((input_data, _), indexes) in enumerate(dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)
        with torch.no_grad():
            pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        # Post processing and put predictions in a dataframe
        for j, pred_strong_it in enumerate(pred_strong):
            for threshold in thresholds:
                # HMM decoding replaces the usual binarization and median filtering:
                # pred_strong_bin = ProbabilityEncoder().binarization(pred_strong_it,
                #                                                     binarization_type="global_threshold",
                #                                                     threshold=threshold)
                # pred_strong_m = scipy.ndimage.filters.median_filter(pred_strong_bin, (median_window, 1))
                pred_strong_bin = apply_HMMs(pred_strong_it, transitions)
                pred_strong_m = pred_strong_bin
                pred = decoder(pred_strong_m)
                pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= pooling_time_ratio / (cfg.sample_rate / cfg.hop_size)
                pred.loc[:, ["onset", "offset"]] = pred[["onset", "offset"]].clip(0, cfg.max_len_seconds)

                pred["filename"] = dataloader.dataset.filenames.iloc[indexes[j]]
                prediction_dfs[threshold] = prediction_dfs[threshold].append(pred, ignore_index=True)
                if i == 0 and j == 0:
                    logger.debug("predictions: \n{}".format(pred))
                    logger.debug("predictions strong: \n{}".format(pred_strong_it))

    # Save predictions
    if save_predictions is not None:
        if isinstance(save_predictions, str):
            if len(thresholds) == 1:
                save_predictions = [save_predictions]
            else:
                base, ext = osp.splitext(save_predictions)
                save_predictions = [osp.join(base, f"{threshold:.3f}{ext}") for threshold in thresholds]
        else:
            assert len(save_predictions) == len(thresholds), \
                f"There should be a prediction file per threshold: len predictions: {len(save_predictions)}\n" \
                f"len thresholds: {len(thresholds)}"
        for ind, threshold in enumerate(thresholds):
            dir_to_create = osp.dirname(save_predictions[ind])
            if dir_to_create != "":
                os.makedirs(dir_to_create, exist_ok=True)
                if ind % 10 == 0:
                    logger.info(f"Saving predictions at: {save_predictions[ind]}. {ind + 1} / {len(thresholds)}")
            prediction_dfs[threshold].to_csv(save_predictions[ind], index=False, sep="\t",
                                             float_format="%.3f")

    list_predictions = []
    for key in prediction_dfs:
        list_predictions.append(prediction_dfs[key])

    if len(list_predictions) == 1:
        list_predictions = list_predictions[0]

    return list_predictions
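# get_transitions and apply_HMMs are not defined in this excerpt. A plausible
# minimal sketch (an assumption, not the original implementation): one two-state
# HMM per class whose self-transition probability comes from the list above,
# Viterbi-decoding the frame posteriors into a binary activity mask.
import numpy as np

def get_transitions(self_probs):
    # One 2x2 transition matrix per class; state 0 = event off, state 1 = event on
    return [np.array([[p, 1.0 - p], [1.0 - p, p]]) for p in self_probs]

def apply_HMMs(pred, transitions):
    # pred: (n_frames, n_classes) frame posteriors in [0, 1] -> binary mask
    n_frames, n_classes = pred.shape
    out = np.zeros_like(pred)
    for c in range(min(n_classes, len(transitions))):
        log_a = np.log(transitions[c] + 1e-10)
        log_e = np.log(np.stack([1.0 - pred[:, c], pred[:, c]], axis=1) + 1e-10)
        delta = np.log(0.5) + log_e[0]  # uniform initial state prior
        back = np.zeros((n_frames, 2), dtype=int)
        for t in range(1, n_frames):
            scores = delta[:, None] + log_a  # scores[i, j]: prev state i -> state j
            back[t] = scores.argmax(axis=0)
            delta = scores.max(axis=0) + log_e[t]
        state = int(delta.argmax())  # backtrack the best state sequence
        for t in range(n_frames - 1, -1, -1):
            out[t, c] = state
            state = back[t, state]
    return out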
save_best_cb = SaveBest("sup") if cfg.early_stopping is not None: early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup", init_patience=cfg.es_init_wait) # ############## # Train # ############## results = pd.DataFrame( columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"]) for epoch in range(cfg.n_epoch): wcrnn.train() wcrnn_ema.train() wcrnn, wcrnn_ema = to_cuda_if_available(wcrnn, wcrnn_ema) loss_value = train(training_loader, wcrnn, optim, epoch, ema_model=wcrnn_ema, mask_weak=weak_mask, mask_strong=strong_mask, adjust_lr=cfg.adjust_lr) # Validation wcrnn = wcrnn.eval() logger.info("\n ### Valid synthetic metric ### \n") predictions = get_predictions(wcrnn, valid_synth_loader, many_hot_encoder.decode_strong,
def get_predictions_ss_late_integration(model, valid_dataload, decoder, pooling_time_ratio=1, thresholds=[0.5],
                                        median_window=1, save_predictions=None, alpha=1):
    """ Get the predictions of a trained model on a specific set
    Args:
        model: torch.Module, a trained pytorch model (you usually want it to be in .eval() mode).
        valid_dataload: DataLoadDf, giving ((input_data, label), index) but label is not used here,
            the multiple data are the multiple sources (the mixture should always be the first one to appear,
            and then the sources). Example: if the input data is (3, 1, timesteps, freq) there is the mixture
            and 2 sources.
        decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and return a list of lists
            of ("event_label", "onset", "offset") for each label predicted.
        pooling_time_ratio: the division to make between timesteps as input and timesteps as output
        median_window: int, the median window (in number of time steps) to be applied
        save_predictions: str or list, the path of the base_filename to save the predictions or a list of names
            corresponding for each thresholds
        thresholds: list, list of threshold to be applied
        alpha: float, the value of the norm to combine the predictions
    Returns:
        dict of the different predictions with associated threshold
    """
    # Init a dataframe per threshold
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()

    # Get predictions
    for i, ((input_data, _), index) in enumerate(valid_dataload):
        input_data = to_cuda_if_available(input_data)
        with torch.no_grad():
            pred_strong, _ = model(input_data)
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        # Combine the mixture prediction with the alpha-norm of the source predictions
        pred_strong_sources = pred_strong[1:]
        pred_strong_sources = norm_alpha(pred_strong_sources, alpha)
        pred_strong_comb = norm_alpha(np.stack((pred_strong[0], pred_strong_sources), 0), alpha)

        # Get different post processing per threshold
        for threshold in thresholds:
            pred_strong_bin = ProbabilityEncoder().binarization(pred_strong_comb,
                                                                binarization_type="global_threshold",
                                                                threshold=threshold)
            pred_strong_m = scipy.ndimage.filters.median_filter(pred_strong_bin, (median_window, 1))
            pred = decoder(pred_strong_m)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            # Put them in seconds
            pred.loc[:, ["onset", "offset"]] *= pooling_time_ratio / (cfg.sample_rate / cfg.hop_size)
            pred.loc[:, ["onset", "offset"]] = pred[["onset", "offset"]].clip(0, cfg.max_len_seconds)

            pred["filename"] = valid_dataload.filenames.iloc[index]
            prediction_dfs[threshold] = prediction_dfs[threshold].append(pred, ignore_index=True)
            if i == 0:
                logger.debug("predictions: \n{}".format(pred))
                logger.debug("predictions strong: \n{}".format(pred_strong_comb))

    # Save predictions
    if save_predictions is not None:
        if isinstance(save_predictions, str):
            if len(thresholds) == 1:
                save_predictions = [save_predictions]
            else:
                base, ext = osp.splitext(save_predictions)
                save_predictions = [osp.join(base, f"{threshold:.3f}{ext}") for threshold in thresholds]
        else:
            assert len(save_predictions) == len(thresholds), \
                f"There should be a prediction file per threshold: len predictions: {len(save_predictions)}\n" \
                f"len thresholds: {len(thresholds)}"
        for ind, threshold in enumerate(thresholds):
            dir_to_create = osp.dirname(save_predictions[ind])
            if dir_to_create != "":
                os.makedirs(dir_to_create, exist_ok=True)
                if ind % 10 == 0:
                    logger.info(f"Saving predictions at: {save_predictions[ind]}. {ind + 1} / {len(thresholds)}")
            prediction_dfs[threshold].to_csv(save_predictions[ind], index=False, sep="\t",
                                             float_format="%.3f")

    list_predictions = []
    for key in prediction_dfs:
        list_predictions.append(prediction_dfs[key])

    if len(list_predictions) == 1:
        list_predictions = list_predictions[0]

    return list_predictions
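# norm_alpha is not defined in this excerpt. A minimal sketch consistent with
# how it is used above (an assumption): a generalized (power) mean over the
# first axis. alpha=1 is a plain average; larger alpha weights the most
# confident source more heavily.
import numpy as np

def norm_alpha(pred, alpha):
    # pred: (n_sources, n_frames, n_classes) -> (n_frames, n_classes)
    return np.mean(pred ** alpha, axis=0) ** (1.0 / alpha)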
def get_predictions(encoder, ema, ms, dataloader, decoder, pooling_time_ratio=1, thresholds=[0.5],
                    median_window=1, save_predictions=None):
    """ Get the predictions of a trained model on a specific set
    Args:
        encoder: torch.Module, the trained PS model, returning strong/weak predictions and CNN features
            (you usually want it to be in .eval() mode).
        ema: torch.Module, the EMA (teacher) model
        ms: torch.Module, the trained MS model, taking the CNN features and returning strong/weak predictions
        dataloader: torch.utils.data.DataLoader, giving ((input_data, label), indexes) but label is not used here
        decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and return a list of lists
            of ("event_label", "onset", "offset") for each label predicted.
        pooling_time_ratio: the division to make between timesteps as input and timesteps as output
        median_window: int, the median window (in number of time steps) to be applied
        save_predictions: str or list, the path of the base_filename to save the predictions or a list of names
            corresponding for each thresholds
        thresholds: list, list of threshold to be applied
    Returns:
        dict of the different predictions with associated threshold
    """
    # Init a dataframe per threshold
    prediction_dfs = {}
    for threshold in thresholds:
        prediction_dfs[threshold] = pd.DataFrame()
    ratio = 0.5

    # Get predictions
    for i, ((input_data, _), indexes) in enumerate(dataloader):
        indexes = indexes.numpy()
        input_data = to_cuda_if_available(input_data)
        with torch.no_grad():
            enc_strong, enc_weak, feature = encoder(input_data)
            pred_strong, pred_weak = ms(feature)
            # Gate the frame-level predictions with the clip-level labels derived
            # from the averaged weak predictions of both models
            phi = generate_label(ratio * (enc_weak + pred_weak)).unsqueeze(1)
            pred_strong = pred_strong * phi
        enc_strong = enc_strong.cpu()
        enc_strong = enc_strong.detach().numpy()
        pred_strong = pred_strong.cpu()
        pred_strong = pred_strong.detach().numpy()
        if i == 0:
            logger.debug(pred_strong)

        # Post processing and put predictions in a dataframe
        for j, pred_strong_it in enumerate(pred_strong):
            for threshold in thresholds:
                # enc_strong_it = adaptive_median_filter(enc_strong[j], decoder)
                pred_strong_it = adaptive_median_filter(pred_strong_it, decoder)
                pred_strong_bin = ProbabilityEncoder().binarization(pred_strong_it,
                                                                    binarization_type="global_threshold",
                                                                    threshold=threshold)
                # enc_strong_bin = ProbabilityEncoder().binarization(enc_strong_it,
                #                                                    binarization_type="global_threshold",
                #                                                    threshold=threshold)
                pred_strong_m = adaptive_median_filter(pred_strong_bin, decoder)
                # enc_strong_m = adaptive_median_filter(enc_strong_bin, decoder)
                pred = decoder(pred_strong_m)
                pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
                # Put them in seconds
                pred.loc[:, ["onset", "offset"]] *= pooling_time_ratio / (cfg.sample_rate / cfg.hop_size)
                pred.loc[:, ["onset", "offset"]] = pred[["onset", "offset"]].clip(0, cfg.max_len_seconds)

                pred["filename"] = dataloader.dataset.filenames.iloc[indexes[j]]
                prediction_dfs[threshold] = prediction_dfs[threshold].append(pred, ignore_index=True)
                if i == 0 and j == 0:
                    logger.debug("predictions: \n{}".format(pred))
                    logger.debug("predictions strong: \n{}".format(pred_strong_it))

    # Save predictions
    if save_predictions is not None:
        if isinstance(save_predictions, str):
            if len(thresholds) == 1:
                save_predictions = [save_predictions]
            else:
                base, ext = osp.splitext(save_predictions)
                save_predictions = [osp.join(base, f"{threshold:.3f}{ext}") for threshold in thresholds]
        else:
            assert len(save_predictions) == len(thresholds), \
                f"There should be a prediction file per threshold: len predictions: {len(save_predictions)}\n" \
                f"len thresholds: {len(thresholds)}"
        for ind, threshold in enumerate(thresholds):
            dir_to_create = osp.dirname(save_predictions[ind])
            if dir_to_create != "":
                os.makedirs(dir_to_create, exist_ok=True)
                if ind % 10 == 0:
                    logger.info(f"Saving predictions at: {save_predictions[ind]}. {ind + 1} / {len(thresholds)}")
            prediction_dfs[threshold].to_csv(save_predictions[ind], index=False, sep="\t",
                                             float_format="%.3f")

    list_predictions = []
    for key in prediction_dfs:
        list_predictions.append(prediction_dfs[key])

    if len(list_predictions) == 1:
        list_predictions = list_predictions[0]

    return list_predictions
save_best_cb = SaveBest("sup")
if cfg.early_stopping is not None:
    early_stopping_call = EarlyStopping(patience=cfg.early_stopping, val_comp="sup",
                                        init_patience=cfg.es_init_wait)

# ##############
# Train
# ##############
results = pd.DataFrame(columns=["loss", "valid_synth_f1", "weak_metric", "global_valid"])
for epoch in range(cfg.n_epoch):
    MS_model.train()
    MS_ema.train()
    PS_model.train()
    MS_model, MS_ema, PS_model = to_cuda_if_available(MS_model, MS_ema, PS_model)

    loss_value = train(training_loader, MS_model, PS_model, ms_optim, ps_optim, epoch,
                       ema_model=MS_ema, mask_weak=weak_mask, mask_strong=strong_mask,
                       adjust_lr=cfg.adjust_lr)

    # Validation
    ema = MS_ema.eval()
    MS_m = MS_model.eval()
    PS_m = PS_model.eval()
    logger.info("\n ### Valid synthetic metric ### \n")
    predictions = get_predictions(PS_m, ema, MS_m, valid_synth_loader, many_hot_encoder.decode_strong,
                                  pooling_time_ratio, median_window=median_window,
                                  save_predictions=None)

    # Validation with synthetic data (dropping feature_filename for psds)
    valid_synth = dfs["validation"].drop("feature_filename", axis=1)
    valid_synth_f1, psds_m_f1 = compute_metrics(predictions, valid_synth, durations_validation)