def prepare_data(self):
        """Instantiate the dataset named by ``self.hparams.dataset`` and split it.

        On success the dataset object is stored on ``self.dataset`` and its
        splits on ``self.train_dataset`` / ``self.val_dataset``. A dataset
        name other than ``"US8K"``, ``"ESC50"`` or ``"SONYCUST"`` leaves the
        instance untouched.
        """
        name = self.hparams.dataset

        # US8K and ESC50 take the same fold-based constructor arguments and
        # expose the same two-way split.
        if name == "US8K" or name == "ESC50":
            dataset_cls = (UrbanSound8K_TALNet
                           if name == "US8K" else ESC50_TALNet)
            self.dataset = dataset_cls(dataset_folder=self.dataset_folder,
                                       fold=self.fold)
            splits = self.dataset.train_validation_split()
            self.train_dataset, self.val_dataset = splits
            return

        if name == "SONYCUST":
            self.dataset = SONYCUST_TALNet(
                sonycust_folder=self.dataset_folder,
                mode="both",
                cleaning_strat="DCASE",
            )
            # The third element (test split) is not kept by this method.
            splits = self.dataset.train_validation_test_split()
            self.train_dataset, self.val_dataset, _ = splits
    def prepare_data(self):
        """Build the SONYCUST dataset and its three splits, at most once.

        Returns ``True`` immediately when ``self.data_prepared`` is already
        set. Otherwise it creates ``self.dataset``, fills
        ``self.train_dataset`` / ``self.val_dataset`` / ``self.test_dataset``,
        sets ``self.data_prepared`` and returns ``None``.
        """
        if self.data_prepared:
            return True

        # Augmentation pipeline handed to the dataset as its ``transform``.
        augmentations = Compose([
            ShiftScaleRotate(shift_limit=0.1,
                             scale_limit=0.1,
                             rotate_limit=0.5),
            GridDistortion(),
            Cutout(),
        ])

        # Dataset options, all taken from the hyper-parameters.
        dataset_kwargs = dict(
            mode=self.hparams.output_mode,
            transform=augmentations,
            metadata=self.hparams.metadata,
            one_hot_time=self.hparams.one_hot_time,
            consensus_threshold=self.hparams.consensus_threshold,
            cleaning_strat=self.hparams.cleaning_strat,
            relabeled_name=self.hparams.relabeled_name,
        )

        self.dataset = SONYCUST_TALNet(self.hparams.path_to_SONYCUST,
                                       **dataset_kwargs)
        splits = self.dataset.train_validation_test_split()
        self.train_dataset, self.val_dataset, self.test_dataset = splits
        self.data_prepared = True
    def prepare_data(self):
        """Create the dataset named by ``self.hparams.dataset`` and split it.

        The same augmentation pipeline is passed to whichever dataset is
        built. On success ``self.dataset``, ``self.train_dataset`` and
        ``self.val_dataset`` are set. A dataset name other than ``"US8K"``,
        ``"ESC50"`` or ``"SONYCUST"`` leaves the instance untouched.
        """
        # Augmentations handed to the dataset as its ``transform``.
        transformation_list = [
            ShiftScaleRotate(shift_limit=0.1,
                             scale_limit=0.1,
                             rotate_limit=0.5),
            GridDistortion(),
            Cutout(),
        ]
        albumentations_transform = Compose(transformation_list)

        # Creating dataset
        if self.hparams.dataset == "US8K":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
                "transform": albumentations_transform,
            }
            self.dataset = UrbanSound8K_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()

        elif self.hparams.dataset == "ESC50":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
                "transform": albumentations_transform,
            }
            self.dataset = ESC50_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()

        elif self.hparams.dataset == "SONYCUST":
            # "cleaning_strat" used to appear twice in this literal; only the
            # last occurrence (the hyper-parameter) ever took effect, so the
            # dead hard-coded "DCASE" entry has been dropped.
            data_param = {
                "sonycust_folder": self.dataset_folder,
                "mode": "both",
                "transform": albumentations_transform,
                "cleaning_strat": self.hparams.cleaning_strat,
                "relabeled_name": "best2.csv",
            }
            self.dataset = SONYCUST_TALNet(**data_param)
            # The third element (test split) is not kept by this method.
            (self.train_dataset, self.val_dataset,
             _) = self.dataset.train_validation_test_split()
class DWSCClassifier(LightningModule):
    """Lightning module training a ``DESSEDDilatedTag`` model for audio tagging.

    The dataset is chosen through ``hparams.dataset`` (``"US8K"``,
    ``"ESC50"`` or ``"SONYCUST"``). For SONYCUST the 31 model outputs are
    split into 8 coarse and 23 fine scores, each with its own loss; the
    other datasets use a single element-wise ``BCELoss``.

    ``self.best_scores`` keeps the best value seen so far for every
    validation metric (5 entries, or 10 for SONYCUST: coarse then fine).
    """

    def __init__(self, hparams, fold):
        """Set dataset-specific attributes and build the model and loss(es).

        Args:
            hparams: Namespace-like object carrying the options declared in
                ``add_model_specific_args``.
            fold: Fold identifier forwarded to the US8K / ESC50 datasets.

        Raises:
            ValueError: If ``hparams.dataset`` is not a supported dataset.
        """
        super().__init__()

        # Save hparams for later
        self.hparams = hparams
        self.fold = fold

        if self.hparams.dataset == "US8K":
            self.dataset_folder = config.path_to_UrbanSound8K
            self.nb_classes = 10
            self.best_scores = [0] * 5
        elif self.hparams.dataset == "ESC50":
            self.dataset_folder = config.path_to_ESC50
            self.nb_classes = 50
            self.best_scores = [0] * 5
        elif self.hparams.dataset == "SONYCUST":
            self.dataset_folder = config.path_to_SONYCUST
            self.nb_classes = 31
            self.best_scores = [0] * 10
        else:
            # Fail here with a clear message rather than later with an
            # unrelated AttributeError on ``self.nb_classes``.
            raise ValueError("Unsupported dataset: {!r}".format(
                self.hparams.dataset))

        # Settings for the SED models
        model_param = {
            "cnn_channels": self.hparams.cnn_channels,
            "cnn_dropout": self.hparams.cnn_dropout,
            "dilated_output_channels": self.hparams.dilated_output_channels,
            "dilated_kernel_size":
            self.hparams.dilated_kernel_size,  # time, feature
            "dilated_stride": self.hparams.dilated_stride,  # time, feature
            "dilated_padding": self.hparams.dilated_padding,
            "dilation_shape": self.hparams.dilation_shape,
            "dilated_nb_features": 84,
            "nb_classes": self.nb_classes,
            "inner_kernel_size": 3,
            "inner_padding": 1,
        }

        self.model = DESSEDDilatedTag(**model_param)
        if self.hparams.dataset != "SONYCUST":
            self.loss = BCELoss(reduction="none")
        else:
            # Separate losses for the coarse and the fine outputs.
            self.loss_c = BCELoss(reduction="none")
            self.loss_f = Masked_loss(BCELoss(reduction="none"))

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's options.

        Args:
            parent_parser: ``ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``ArgumentParser`` (``add_help=False``) holding the model,
            optimisation and dataset options.
        """

        def _str2bool(value):
            """Parse a command-line boolean.

            ``type=bool`` would turn any non-empty string, including
            "False", into ``True``.
            """
            if isinstance(value, bool):
                return value
            lowered = value.strip().lower()
            if lowered in ("true", "t", "yes", "y", "on", "1"):
                return True
            if lowered in ("false", "f", "no", "n", "off", "0", ""):
                return False
            # argparse reports ValueError as an "invalid value" usage error.
            raise ValueError("expected a boolean, got {!r}".format(value))

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--cnn_channels", type=int, default=256)
        parser.add_argument("--cnn_dropout", type=float, default=0.25)
        parser.add_argument("--dilated_output_channels", type=int, default=256)
        parser.add_argument("--dilated_kernel_size",
                            nargs="+",
                            type=int,
                            default=[7, 7])
        parser.add_argument("--dilated_stride",
                            nargs="+",
                            type=int,
                            default=[1, 3])
        parser.add_argument("--dilated_padding",
                            nargs="+",
                            type=int,
                            default=[30, 0])
        parser.add_argument("--dilation_shape",
                            nargs="+",
                            type=int,
                            default=[10, 1])

        parser.add_argument(
            "--pooling",
            type=str,
            default="att",
            choices=["max", "ave", "lin", "exp", "att", "auto"],
        )

        parser.add_argument("--batch_size", type=int, default=24)
        parser.add_argument("--shuffle", type=_str2bool, default=True)
        parser.add_argument("--init_lr", type=float, default=1e-3)

        parser.add_argument("--num_mels", type=int, default=64)
        parser.add_argument("--dataset",
                            type=str,
                            default="SONYCUST",
                            choices=["US8K", "ESC50", "SONYCUST"])

        return parser

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def prepare_data(self):
        """Create the dataset named by ``hparams.dataset`` and split it.

        Sets ``self.dataset``, ``self.train_dataset`` and
        ``self.val_dataset``.
        """
        if self.hparams.dataset == "US8K":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
            }
            self.dataset = UrbanSound8K_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()

        elif self.hparams.dataset == "ESC50":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
            }
            self.dataset = ESC50_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()

        elif self.hparams.dataset == "SONYCUST":
            data_param = {
                "sonycust_folder": self.dataset_folder,
                "mode": "both",
                "cleaning_strat": "DCASE",
            }
            self.dataset = SONYCUST_TALNet(**data_param)
            # The third element (test split) is not kept by this method.
            (self.train_dataset, self.val_dataset,
             _) = self.dataset.train_validation_test_split()

    def train_dataloader(self):
        """Return the ``DataLoader`` over the training split."""
        return DataLoader(self.train_dataset,
                          batch_size=self.hparams.batch_size,
                          shuffle=self.hparams.shuffle,
                          num_workers=4)

    def val_dataloader(self):
        """Return the ``DataLoader`` over the validation split."""
        return DataLoader(self.val_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=4)

    def configure_optimizers(self):
        """Return a ``Lookahead`` optimizer wrapping ``Ralamb``.

        The base optimizer runs over the model parameters with learning
        rate ``hparams.init_lr``.
        """
        base_optim_param = {"lr": self.hparams.init_lr}
        base_optim = Ralamb(self.model.parameters(), **base_optim_param)
        optim_param = {"k": 5, "alpha": 0.5}
        optimizer = Lookahead(base_optim, **optim_param)

        return optimizer

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, output, target)``.

        The loss is returned without the final averaging so that training
        can reduce it to a scalar while validation keeps the tensor for
        aggregation over the epoch.
        """
        data = batch["input_vector"].float()

        if self.hparams.dataset != "SONYCUST":
            target = batch["label"].float()
            output = self.forward(data)
            return self.loss(output, target), output, target

        target_c = batch["label"]["coarse"].float()
        target_f = batch["label"]["full_fine"].float()
        target = torch.cat([target_c, target_f], 1)
        output = self.forward(data)
        # First 8 outputs are the coarse scores, the next 23 the fine ones.
        outputs_c, outputs_f = torch.split(output, [8, 23], 1)
        loss = torch.cat(
            [
                self.loss_c(outputs_c, target_c).mean(0),
                self.loss_f(outputs_f, target_f),
            ],
            0,
        )
        return loss, output, target

    def training_step(self, batch, batch_idx):
        """Compute the mean training loss of ``batch`` and log it."""
        loss, _, _ = self._shared_step(batch)
        loss = loss.mean()

        return {"loss": loss, "log": {"1_loss/train_loss": loss}}

    def validation_step(self, batch, batch_idx):
        """Return the un-averaged loss, outputs and targets of ``batch``."""
        loss, output, target = self._shared_step(batch)

        return {
            "val_loss": loss,
            "output": output,
            "target": target,
        }

    def _update_best_scores(self, scores):
        """Raise each entry of ``self.best_scores`` that ``scores`` beats."""
        for idx, score in enumerate(scores):
            if score > self.best_scores[idx]:
                self.best_scores[idx] = score

    def validation_epoch_end(self, outputs):
        """Aggregate validation batches, compute metrics and build the logs.

        Args:
            outputs: List of the dictionaries returned by
                ``validation_step``.

        Returns:
            A dictionary with the ``"progress_bar"`` and ``"log"`` entries.
        """
        val_loss = torch.cat([o["val_loss"] for o in outputs], 0).mean()
        all_outputs = torch.cat([o["output"] for o in outputs],
                                0).cpu().numpy()
        all_targets = torch.cat([o["target"] for o in outputs],
                                0).cpu().numpy()

        if self.hparams.dataset == "SONYCUST":
            # Logic for SONYCUST
            # 1 marks a target column to drop; after inversion the mask keeps
            # 31 of the 37 target columns, matching the 31 model outputs.
            # Presumably the dropped columns are the "other/unknown" fine
            # labels -- TODO confirm against the dataset.
            X_mask = ~torch.BoolTensor([
                0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0,
            ])
            outputs_split = np.split(all_outputs, [8, 31], 1)
            all_outputs_coarse, all_outputs_fine = outputs_split[:2]

            all_targets = all_targets[:, X_mask]
            targets_split = np.split(all_targets, [8, 31], 1)
            all_targets_coarse, all_targets_fine = targets_split[:2]

            accuracy_c = accuracy(all_targets_coarse, all_outputs_coarse)
            f1_micro_c = compute_micro_F1(all_targets_coarse,
                                          all_outputs_coarse)
            auprc_micro_c = compute_micro_auprc(all_targets_coarse,
                                                all_outputs_coarse)
            auprc_macro_c = compute_macro_auprc(all_targets_coarse,
                                                all_outputs_coarse)
            map_coarse = mean_average_precision(all_targets_coarse,
                                                all_outputs_coarse)

            accuracy_f = accuracy(all_targets_fine, all_outputs_fine)
            f1_micro_f = compute_micro_F1(all_targets_fine, all_outputs_fine)
            auprc_micro_f = compute_micro_auprc(all_targets_fine,
                                                all_outputs_fine)
            auprc_macro_f = compute_macro_auprc(all_targets_fine,
                                                all_outputs_fine)
            map_fine = mean_average_precision(all_targets_fine,
                                              all_outputs_fine)

            # Slots 0-4 hold the coarse metrics, 5-9 the fine ones.
            self._update_best_scores([
                accuracy_c, f1_micro_c, auprc_micro_c, auprc_macro_c,
                map_coarse, accuracy_f, f1_micro_f, auprc_micro_f,
                auprc_macro_f, map_fine
            ])

            # Accuracy and F1 need distinct keys: they previously shared one
            # mangled key, so F1 silently overwrote accuracy in the log.
            log_temp = {
                "2_valid_coarse/1_accuracy@0.5": accuracy_c,
                "2_valid_coarse/1_f1_micro@0.5": f1_micro_c,
                "2_valid_coarse/1_auprc_micro": auprc_micro_c,
                "2_valid_coarse/1_auprc_macro": auprc_macro_c,
                "2_valid_coarse/1_map_coarse": map_coarse,
                "3_valid_fine/1_accuracy@0.5": accuracy_f,
                "3_valid_fine/1_f1_micro@0.5": f1_micro_f,
                "3_valid_fine/1_auprc_micro": auprc_micro_f,
                "3_valid_fine/1_auprc_macro": auprc_macro_f,
                "3_valid_fine/1_map_fine": map_fine,
            }

            tqdm_dict = {
                "val_loss": val_loss,
                "m_auprc_c": auprc_macro_c,
            }

        else:
            # Logic for ESC50 and US8K
            accuracy_score = accuracy(all_targets, all_outputs)
            f1_micro = compute_micro_F1(all_targets, all_outputs)
            auprc_micro = compute_micro_auprc(all_targets, all_outputs)
            # With the extra flag the helper returns a pair; only the second
            # element is used here.
            _, auprc_macro = compute_macro_auprc(all_targets, all_outputs,
                                                 True)
            map_score = mean_average_precision(all_targets, all_outputs)

            self._update_best_scores(
                [accuracy_score, f1_micro, auprc_micro, auprc_macro,
                 map_score])

            log_temp = {
                "2_valid/1_accuracy0.5": accuracy_score,
                "2_valid/1_f1_micro0.5": f1_micro,
                "2_valid/1_auprc_micro": auprc_micro,
                "2_valid/1_auprc_macro": auprc_macro,
                "2_valid/1_map": map_score,
            }

            tqdm_dict = {
                "val_loss": val_loss,
                "acc": accuracy_score,
            }

        log = {
            "step": self.current_epoch,
            "1_loss/val_loss": val_loss,
        }

        log.update(log_temp)

        return {"progress_bar": tqdm_dict, "log": log}
class TALNetV3Classifier(LightningModule):
    """Lightning module fine-tuning a ``TALNetV3NoMeta`` model for tagging.

    The dataset is chosen through ``hparams.dataset`` (``"US8K"``,
    ``"ESC50"`` or ``"SONYCUST"``). Pretrained weights are loaded from the
    checkpoint at ``config.audioset`` for every layer whose name and shape
    match. For SONYCUST the 31 outputs are split into 8 coarse and 23 fine
    scores, each with its own loss; the other datasets use a single
    element-wise ``BCELoss``.

    ``self.best_scores`` keeps the best value seen so far for every
    validation metric (5 entries, or 10 for SONYCUST: coarse then fine).
    """

    def __init__(self, hparams, fold):
        """Set dataset-specific attributes and build the model and loss(es).

        Args:
            hparams: Namespace-like object carrying the options declared in
                ``add_model_specific_args``.
            fold: Fold identifier forwarded to the US8K / ESC50 datasets.

        Raises:
            ValueError: If ``hparams.dataset`` is not a supported dataset.
        """
        super().__init__()

        # Save hparams for later
        self.hparams = hparams
        self.fold = fold

        if self.hparams.dataset == "US8K":
            self.dataset_folder = config.path_to_UrbanSound8K
            self.nb_classes = 10
            self.input_size = (162, 64)
            self.best_scores = [0] * 5
        elif self.hparams.dataset == "ESC50":
            self.dataset_folder = config.path_to_ESC50
            self.nb_classes = 50
            self.input_size = (200, 64)
            self.best_scores = [0] * 5
        elif self.hparams.dataset == "SONYCUST":
            self.dataset_folder = config.path_to_SONYCUST
            self.nb_classes = 31
            self.input_size = (400, 64)
            self.best_scores = [0] * 10
        else:
            # Fail here with a clear message rather than later with an
            # unrelated AttributeError on ``self.nb_classes``.
            raise ValueError("Unsupported dataset: {!r}".format(
                self.hparams.dataset))

        model_param = {
            "num_mels": hparams.num_mels,
            "num_classes": self.nb_classes,
        }
        self.model = TALNetV3NoMeta(hparams.__dict__, **model_param)

        # Load every pretrained layers matching our layers
        pretrained_dict = torch.load(config.audioset,
                                     map_location="cpu")["model"]
        model_dict = self.model.state_dict()
        # Keep only tensors present in our model under the same name and
        # with the same shape; everything else keeps its fresh weights.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

        if self.hparams.dataset != "SONYCUST":
            self.loss = BCELoss(reduction="none")
        else:
            # Separate losses for the coarse and the fine outputs.
            self.loss_c = BCELoss(reduction="none")
            self.loss_f = Masked_loss(BCELoss(reduction="none"))

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's options.

        Args:
            parent_parser: ``ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``ArgumentParser`` (``add_help=False``) holding the model,
            optimisation and dataset options.
        """

        def _str2bool(value):
            """Parse a command-line boolean.

            ``type=bool`` would turn any non-empty string, including
            "False", into ``True``.
            """
            if isinstance(value, bool):
                return value
            lowered = value.strip().lower()
            if lowered in ("true", "t", "yes", "y", "on", "1"):
                return True
            if lowered in ("false", "f", "no", "n", "off", "0", ""):
                return False
            # argparse reports ValueError as an "invalid value" usage error.
            raise ValueError("expected a boolean, got {!r}".format(value))

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--dropout", type=float, default=0.0)
        parser.add_argument(
            "--pooling",
            type=str,
            default="att",
            choices=["max", "ave", "lin", "exp", "att", "auto"],
        )
        parser.add_argument("--n_conv_layers", type=int, default=10)
        parser.add_argument("--kernel_size", type=str, default="3")
        parser.add_argument("--n_pool_layers", type=int, default=5)
        parser.add_argument("--embedding_size", type=int, default=1024)
        parser.add_argument("--batch_norm", type=_str2bool, default=True)
        parser.add_argument(
            "--conv_activation",
            type=str,
            default="mish",
            choices=["relu", "prelu", "leaky_relu", "mish"],
        )
        parser.add_argument("--n_head", type=int, default=8)
        parser.add_argument("--d_kv", type=int, default=128)
        parser.add_argument("--dropout_transfo", type=float, default=0.2)
        # TalnetV3
        parser.add_argument("--transfo_head", type=int, default=16)
        parser.add_argument("--dropout_AS", type=float, default=0.0)

        parser.add_argument("--batch_size", type=int, default=52)
        parser.add_argument("--init_lr", type=float, default=1e-3)
        parser.add_argument("--weight_decay", type=float, default=0)  # 1e-5)
        parser.add_argument("--start_mixup", type=int, default=-1)

        parser.add_argument("--shuffle", type=_str2bool, default=True)
        parser.add_argument("--num_mels", type=int, default=64)
        parser.add_argument("--dataset",
                            type=str,
                            default="SONYCUST",
                            choices=["US8K", "ESC50", "SONYCUST"])
        parser.add_argument("--cleaning_strat",
                            type=str,
                            default="DCASE",
                            choices=["DCASE", "Relabeled"])

        return parser

    def forward(self, x):
        """Return the wrapped model's raw output for ``x``.

        Callers in this class use only the first element of that output.
        """
        return self.model(x)

    def prepare_data(self):
        """Create the dataset named by ``hparams.dataset`` and split it.

        The same augmentation pipeline is passed to whichever dataset is
        built. Sets ``self.dataset``, ``self.train_dataset`` and
        ``self.val_dataset``.
        """
        # Augmentations handed to the dataset as its ``transform``.
        transformation_list = [
            ShiftScaleRotate(shift_limit=0.1,
                             scale_limit=0.1,
                             rotate_limit=0.5),
            GridDistortion(),
            Cutout(),
        ]
        albumentations_transform = Compose(transformation_list)

        # Creating dataset
        if self.hparams.dataset == "US8K":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
                "transform": albumentations_transform,
            }
            self.dataset = UrbanSound8K_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()

        elif self.hparams.dataset == "ESC50":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
                "transform": albumentations_transform,
            }
            self.dataset = ESC50_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()

        elif self.hparams.dataset == "SONYCUST":
            # "cleaning_strat" used to appear twice in this literal; only the
            # last occurrence (the hyper-parameter) ever took effect, so the
            # dead hard-coded "DCASE" entry has been dropped.
            data_param = {
                "sonycust_folder": self.dataset_folder,
                "mode": "both",
                "transform": albumentations_transform,
                "cleaning_strat": self.hparams.cleaning_strat,
                "relabeled_name": "best2.csv",
            }
            self.dataset = SONYCUST_TALNet(**data_param)
            # The third element (test split) is not kept by this method.
            (self.train_dataset, self.val_dataset,
             _) = self.dataset.train_validation_test_split()

    def train_dataloader(self):
        """Return the ``DataLoader`` over the training split."""
        return DataLoader(self.train_dataset,
                          batch_size=self.hparams.batch_size,
                          shuffle=self.hparams.shuffle,
                          num_workers=4)

    def val_dataloader(self):
        """Return the ``DataLoader`` over the validation split."""
        return DataLoader(self.val_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=4)

    def configure_optimizers(self):
        """Return a ``Lookahead`` optimizer wrapping ``Ralamb``.

        The base optimizer runs over the model parameters with learning
        rate ``hparams.init_lr``.
        """
        base_optim_param = {"lr": self.hparams.init_lr}
        base_optim = Ralamb(self.model.parameters(), **base_optim_param)
        optim_param = {"k": 5, "alpha": 0.5}
        optimizer = Lookahead(base_optim, **optim_param)

        return optimizer

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return ``(loss, output, target)``.

        The loss is returned without the final averaging so that training
        can reduce it to a scalar while validation keeps the tensor for
        aggregation over the epoch.
        """
        data = batch["input_vector"].float()

        if self.hparams.dataset != "SONYCUST":
            target = batch["label"].float()
            # Only the first element of the model output is scored.
            output = self.forward(data)[0]
            return self.loss(output, target), output, target

        target_c = batch["label"]["coarse"].float()
        target_f = batch["label"]["full_fine"].float()
        target = torch.cat([target_c, target_f], 1)
        # Only the first element of the model output is scored.
        output = self.forward(data)[0]
        # First 8 outputs are the coarse scores, the next 23 the fine ones.
        outputs_c, outputs_f = torch.split(output, [8, 23], 1)
        loss = torch.cat(
            [
                self.loss_c(outputs_c, target_c).mean(0),
                self.loss_f(outputs_f, target_f),
            ],
            0,
        )
        return loss, output, target

    def training_step(self, batch, batch_idx):
        """Compute the mean training loss of ``batch`` and log it."""
        loss, _, _ = self._shared_step(batch)
        loss = loss.mean()

        return {"loss": loss, "log": {"1_loss/train_loss": loss}}

    def validation_step(self, batch, batch_idx):
        """Return the un-averaged loss, outputs and targets of ``batch``."""
        loss, output, target = self._shared_step(batch)

        return {
            "val_loss": loss,
            "output": output,
            "target": target,
        }

    def _update_best_scores(self, scores):
        """Raise each entry of ``self.best_scores`` that ``scores`` beats."""
        for idx, score in enumerate(scores):
            if score > self.best_scores[idx]:
                self.best_scores[idx] = score

    def validation_epoch_end(self, outputs):
        """Aggregate validation batches, compute metrics and build the logs.

        Args:
            outputs: List of the dictionaries returned by
                ``validation_step``.

        Returns:
            A dictionary with the ``"progress_bar"`` and ``"log"`` entries.
        """
        val_loss = torch.cat([o["val_loss"] for o in outputs], 0).mean()
        all_outputs = torch.cat([o["output"] for o in outputs],
                                0).cpu().numpy()
        all_targets = torch.cat([o["target"] for o in outputs],
                                0).cpu().numpy()

        if self.hparams.dataset == "SONYCUST":
            # Logic for SONYCUST
            # 1 marks a target column to drop; after inversion the mask keeps
            # 31 of the 37 target columns, matching the 31 model outputs.
            # Presumably the dropped columns are the "other/unknown" fine
            # labels -- TODO confirm against the dataset.
            X_mask = ~torch.BoolTensor([
                0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0, 0, 0, 1,
                0, 0, 0, 0, 1,
                0,
            ])
            outputs_split = np.split(all_outputs, [8, 31], 1)
            all_outputs_coarse, all_outputs_fine = outputs_split[:2]

            all_targets = all_targets[:, X_mask]
            targets_split = np.split(all_targets, [8, 31], 1)
            all_targets_coarse, all_targets_fine = targets_split[:2]

            accuracy_c = accuracy(all_targets_coarse, all_outputs_coarse)
            f1_micro_c = compute_micro_F1(all_targets_coarse,
                                          all_outputs_coarse)
            auprc_micro_c = compute_micro_auprc(all_targets_coarse,
                                                all_outputs_coarse)
            auprc_macro_c = compute_macro_auprc(all_targets_coarse,
                                                all_outputs_coarse)
            map_coarse = mean_average_precision(all_targets_coarse,
                                                all_outputs_coarse)

            accuracy_f = accuracy(all_targets_fine, all_outputs_fine)
            f1_micro_f = compute_micro_F1(all_targets_fine, all_outputs_fine)
            auprc_micro_f = compute_micro_auprc(all_targets_fine,
                                                all_outputs_fine)
            auprc_macro_f = compute_macro_auprc(all_targets_fine,
                                                all_outputs_fine)
            map_fine = mean_average_precision(all_targets_fine,
                                              all_outputs_fine)

            # Slots 0-4 hold the coarse metrics, 5-9 the fine ones.
            self._update_best_scores([
                accuracy_c, f1_micro_c, auprc_micro_c, auprc_macro_c,
                map_coarse, accuracy_f, f1_micro_f, auprc_micro_f,
                auprc_macro_f, map_fine
            ])

            # Accuracy and F1 need distinct keys: they previously shared one
            # mangled key, so F1 silently overwrote accuracy in the log.
            log_temp = {
                "2_valid_coarse/1_accuracy@0.5": accuracy_c,
                "2_valid_coarse/1_f1_micro@0.5": f1_micro_c,
                "2_valid_coarse/1_auprc_micro": auprc_micro_c,
                "2_valid_coarse/1_auprc_macro": auprc_macro_c,
                "2_valid_coarse/1_map_coarse": map_coarse,
                "3_valid_fine/1_accuracy@0.5": accuracy_f,
                "3_valid_fine/1_f1_micro@0.5": f1_micro_f,
                "3_valid_fine/1_auprc_micro": auprc_micro_f,
                "3_valid_fine/1_auprc_macro": auprc_macro_f,
                "3_valid_fine/1_map_fine": map_fine,
            }

            tqdm_dict = {
                "val_loss": val_loss,
                "m_auprc_c": auprc_macro_c,
            }

        else:
            # Logic for ESC50 and US8K
            accuracy_score = accuracy(all_targets, all_outputs)
            f1_micro = compute_micro_F1(all_targets, all_outputs)
            auprc_micro = compute_micro_auprc(all_targets, all_outputs)
            # With the extra flag the helper returns a pair; only the second
            # element is used here.
            _, auprc_macro = compute_macro_auprc(all_targets, all_outputs,
                                                 True)
            map_score = mean_average_precision(all_targets, all_outputs)

            self._update_best_scores(
                [accuracy_score, f1_micro, auprc_micro, auprc_macro,
                 map_score])

            log_temp = {
                "2_valid/1_accuracy0.5": accuracy_score,
                "2_valid/1_f1_micro0.5": f1_micro,
                "2_valid/1_auprc_micro": auprc_micro,
                "2_valid/1_auprc_macro": auprc_macro,
                "2_valid/1_map": map_score,
            }

            tqdm_dict = {
                "val_loss": val_loss,
                "acc": accuracy_score,
            }

        log = {
            "step": self.current_epoch,
            "1_loss/val_loss": val_loss,
        }

        log.update(log_temp)

        return {"progress_bar": tqdm_dict, "log": log}
Exemple #6
0
                    nargs='+',
                    default=["latitude", "longitude", "week", "day", "hour"])
parser.add_argument('--consensus_threshold', type=float, default=0.0)
# NOTE(review): argparse applies bool() to the raw string, so any non-empty
# value (including "False") parses as True -- confirm this is intended.
parser.add_argument('--one_hot_time', type=bool, default=False)
args = parser.parse_args()

# Dataset parameters (all taken from the command line)
data_param = {
    'mode': args.output_mode,
    'metadata': args.metadata,
    'one_hot_time': args.one_hot_time,
    'consensus_threshold': args.consensus_threshold
}

# Creating dataset; only the test split is used below
dataset = SONYCUST_TALNet(args.path_to_SONYCUST, **data_param)
train_dataset, valid_dataset, test_dataset = dataset.train_validation_test_split(
)

test_dataloader = DataLoader(test_dataset, batch_size=64, num_workers=4)

# Creating model: restore it from the checkpoint and its hyper-parameter
# file, then freeze it and move it to the first CUDA device
model = DCASETALNetClassifier.load_from_checkpoint(
    args.path_to_ckpt, hparams_file=args.path_to_yaml)
model.freeze()
model.to('cuda:0')
print("Number of parameters : ", count_parameters(model))

print("Computing outputs...")
for i_batch, sample_batched in enumerate(tqdm(test_dataloader), 1):
class DCASETALNetClassifier(LightningModule):
    def __init__(self, hparams):
        """Build the TALNet classifier and warm-start it from pretrained weights.

        Args:
            hparams: Namespace-like object carrying the options declared in
                ``add_model_specific_args``. ``output_mode`` selects the
                number of output classes; the whole object is also passed to
                ``TALNetV3``.
        """
        super().__init__()
        # Save hparams for later
        self.hparams = hparams
        self.path_to_talnet = config.audioset
        self.num_classes_dict = {'coarse': 8, 'fine': 23, 'both': 31}
        # HACK: the datasets are built right here because the number of
        # metadata features is needed to size the model. The flag makes any
        # later prepare_data() call return early.
        self.data_prepared = False
        self.prepare_data()
        num_meta = len(self.dataset[0]['metadata'])

        self.model = TALNetV3(
            self.hparams,
            num_mels=64,
            num_meta=num_meta,
            num_classes=self.num_classes_dict[self.hparams.output_mode])

        # Load every pretrained layer matching our layers: only tensors whose
        # name AND shape agree with the new model are copied, the rest keep
        # their fresh initialisation.
        pretrained_dict = torch.load(self.path_to_talnet,
                                     map_location='cpu')['model']
        model_dict = self.model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

        # Element-wise losses; the reduction is done in the step methods.
        self.loss_c = BCELoss(reduction='none')
        self.loss_f = Masked_loss(BCELoss(reduction='none'))

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with the model options.

        Args:
            parent_parser: An ``ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``ArgumentParser`` (without ``-h``) holding the inherited
            arguments plus the model, optimisation and dataset options below.
        """
        from argparse import ArgumentTypeError

        def _str2bool(value):
            """Convert a command-line string to a bool.

            ``type=bool`` would call ``bool("False")``, which is True; this
            converter understands the usual spellings instead.
            """
            text = str(value).strip().lower()
            if text in ('1', 'true', 't', 'yes', 'y', 'on'):
                return True
            # The empty string stays False, as it was with ``type=bool``.
            if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
                return False
            raise ArgumentTypeError('expected a boolean value, got %r' %
                                    (value, ))

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--output_mode',
                            type=str,
                            default='both',
                            choices=['coarse', 'fine', 'both'])
        parser.add_argument('--weighted_loss', type=_str2bool, default=False)
        parser.add_argument('--dropout', type=float, default=0.0)
        parser.add_argument('--dropout_AS', type=float, default=0.0)
        parser.add_argument('--dropout_transfo', type=float, default=0.2)
        parser.add_argument(
            '--pooling',
            type=str,
            default='att',
            choices=['max', 'ave', 'lin', 'exp', 'att', 'auto'])
        parser.add_argument('--n_conv_layers', type=int, default=10)
        parser.add_argument('--kernel_size', type=str, default='3')
        parser.add_argument('--n_pool_layers', type=int, default=5)
        parser.add_argument('--embedding_size', type=int, default=1024)
        parser.add_argument('--transfo_head', type=int, default=16)
        parser.add_argument('--nb_meta_emb', type=int, default=64)
        parser.add_argument('--batch_norm', type=_str2bool, default=True)

        parser.add_argument('--batch_size', type=int, default=50)
        parser.add_argument('--alpha', type=float, default=0.5)
        parser.add_argument('--init_lr', type=float, default=1e-3)
        parser.add_argument('--weight_decay', type=float, default=1e-5)

        parser.add_argument('--path_to_SONYCUST',
                            type=str,
                            default=config.path_to_SONYCUST)
        parser.add_argument(
            '--metadata',
            nargs='+',
            default=["latitude", "longitude", "week", "day", "hour"])
        parser.add_argument('--consensus_threshold', type=float, default=0.0)
        parser.add_argument('--one_hot_time', type=_str2bool, default=False)

        parser.add_argument('--cleaning_strat',
                            type=str,
                            default='DCASE',
                            choices=['DCASE', 'Relabeled'])
        parser.add_argument('--relabeled_name', type=str, default='best2.csv')

        return parser

    def forward(self, x, meta):
        x = self.model(x, meta)
        return x

    def prepare_data(self):
        """Build the SONYC-UST dataset and its train/validation/test splits once.

        Returns:
            True when the data was already prepared by an earlier call,
            otherwise None after the datasets have been created.
        """
        if self.data_prepared:
            return True

        # Augmentation pipeline handed to the dataset as its transform.
        augmentation = Compose([
            ShiftScaleRotate(shift_limit=0.1,
                             scale_limit=0.1,
                             rotate_limit=0.5),
            GridDistortion(),
            Cutout(),
        ])

        # Dataset parameters
        hp = self.hparams
        dataset_kwargs = dict(
            mode=hp.output_mode,
            transform=augmentation,
            metadata=hp.metadata,
            one_hot_time=hp.one_hot_time,
            consensus_threshold=hp.consensus_threshold,
            cleaning_strat=hp.cleaning_strat,
            relabeled_name=hp.relabeled_name,
        )

        # Creating dataset
        self.dataset = SONYCUST_TALNet(hp.path_to_SONYCUST, **dataset_kwargs)
        (self.train_dataset, self.val_dataset,
         self.test_dataset) = self.dataset.train_validation_test_split()
        self.data_prepared = True

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.hparams.batch_size,
                          shuffle=True,
                          num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=4)

    def configure_optimizers(self):
        """Return a Lookahead optimizer (k=5, alpha=0.5) wrapping Ralamb.

        The inner Ralamb optimizer runs over the model parameters with the
        learning rate taken from ``hparams.init_lr``.
        """
        learning_rate = self.hparams.init_lr
        inner_optimizer = Ralamb(self.model.parameters(), lr=learning_rate)
        return Lookahead(inner_optimizer, k=5, alpha=0.5)

    def training_step(self, batch, batch_idx):
        # Get input vector and labels, do not forget float()
        data, meta, target_c, target_f = batch['input_vector'].float(
        ), batch['metadata'].float(), batch['label']['coarse'].float(
        ), batch['label']['full_fine'].float()
        target = torch.cat([target_c, target_f], 1)

        # Forward pass
        output = self.forward(data, meta)[0]
        outputs_c, outputs_f = torch.split(output, [8, 23], 1)
        # Compute loss of the batch
        loss = torch.cat([
            self.loss_c(outputs_c, target_c).mean(0),
            self.loss_f(outputs_f, target_f)
        ], 0).mean()

        return {'loss': loss, 'log': {'1_loss/train_loss': loss}}

    def validation_step(self, batch, batch_idx):
        # Get input vector and labels, do not forget float()
        filename, data, meta, target_c, target_f = batch['file_name'], batch[
            'input_vector'].float(), batch['metadata'].float(), batch['label'][
                'coarse'].float(), batch['label']['full_fine'].float()
        target = torch.cat([target_c, target_f], 1)
        # Forward pass
        output = self.forward(data, meta)[0]
        outputs_c, outputs_f = torch.split(output, [8, 23], 1)
        # Compute loss of the batch
        loss = torch.cat([
            self.loss_c(outputs_c, target_c).mean(0),
            self.loss_f(outputs_f, target_f)
        ], 0)
        return {
            'val_loss': loss,
            'output': output,
            'target': target,
            'filename': filename
        }

    def validation_epoch_end(self, outputs):
        """Aggregate the validation batches and compute the DCASE metrics.

        The predictions are written to ``temp.csv`` inside
        ``config.path_to_SONYCUST`` and scored through ``evaluate`` for the
        'coarse' and 'fine' modes. The per-batch 'target' tensors are not
        needed here: ``evaluate`` is given the annotation and taxonomy paths
        from ``config`` instead.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        Returns:
            Dict with a 'progress_bar' dict (macro AUPRC of both modes and
            the validation loss) and a 'log' dict with every metric.
        """
        val_loss = torch.cat([o['val_loss'] for o in outputs], 0).mean()

        all_outputs = torch.cat([o['output'] for o in outputs], 0)
        all_outputs_c, all_outputs_f = torch.split(all_outputs, [8, 23], 1)
        all_outputs_c = all_outputs_c.cpu().numpy()
        all_outputs_f = all_outputs_f.cpu().numpy()

        filename_array = [f for o in outputs for f in o['filename']]

        # Trick to use DCASE metrics : we save a csv file
        coarse_columns = self.dataset.idlabel_dict['coarse']
        pred_df = pd.DataFrame(columns=['audio_filename'] + coarse_columns +
                               self.dataset.idlabel_dict['full_fine'])
        pred_df['audio_filename'] = filename_array
        pred_df[coarse_columns] = all_outputs_c
        # NOTE(review): the frame is declared with the 'full_fine' columns
        # but filled through the 'fine' ones; presumably the remaining
        # full_fine columns are meant to stay empty - confirm.
        pred_df[self.dataset.idlabel_dict['fine']] = all_outputs_f
        temp_csv = os.path.join(config.path_to_SONYCUST, "temp.csv")
        pred_df.to_csv(temp_csv, index=False, header=True)

        log = {'1_loss/val_loss': val_loss}
        auprc_macro = {}
        for mode in ('coarse', 'fine'):
            df_dict = evaluate(temp_csv, config.path_to_annotation,
                               config.path_to_taxonomy, mode)
            auprc_micro, eval_df = micro_averaged_auprc(df_dict,
                                                        return_df=True)
            # The class-wise values are returned as well but not logged.
            auprc_macro[mode], _ = macro_averaged_auprc(df_dict,
                                                        return_classwise=True)
            # F1 is read at the first row whose threshold reaches 0.5.
            thresh_0pt5_idx = (eval_df['threshold'] >=
                               0.5).to_numpy().nonzero()[0][0]
            f1_micro = eval_df["F"][thresh_0pt5_idx]

            prefix = '2_valid_' + mode
            log.update({
                prefix + '/1_auprc_macro': auprc_macro[mode],
                prefix + '/2_auprc_micro': auprc_micro,
                prefix + '/3_F1_micro': f1_micro
            })

        tqdm_dict = {
            'auprc_macro_c': auprc_macro['coarse'],
            'auprc_macro_f': auprc_macro['fine'],
            'val_loss': val_loss
        }
        return {'progress_bar': tqdm_dict, 'log': log}
Exemple #8
0
                    default=["latitude", "longitude", "week", "day", "hour"])
parser.add_argument('--consensus_threshold', type=float, default=0.0)
# NOTE(review): argparse calls bool() on the raw string, so every non-empty
# value (including "False") parses as True; only the default gives False.
parser.add_argument('--one_hot_time', type=bool, default=False)
args = parser.parse_args()

# Dataset parameters, forwarded as keyword arguments to SONYCUST_TALNet.
# The cleaning strategy is fixed to 'All_unique' for this script.
data_param = {
    'mode': args.output_mode,
    'metadata': args.metadata,
    'one_hot_time': args.one_hot_time,
    'consensus_threshold': args.consensus_threshold,
    'cleaning_strat': 'All_unique'
}

# Creating dataset
dataset = SONYCUST_TALNet(args.path_to_SONYCUST, **data_param)

# The loader runs over the whole dataset; no train/validation/test split is
# made here.
test_dataloader = DataLoader(dataset, batch_size=64, num_workers=4)

# Creating model: restore the classifier from the checkpoint file, with the
# hyper-parameters file given by --path_to_yaml.
model = DCASETALNetClassifier.load_from_checkpoint(
    args.path_to_ckpt, hparams_file=args.path_to_yaml)
model.freeze()
# NOTE(review): the device is hard-coded to the first CUDA GPU.
model.to('cuda:0')

print("Computing new labels...")
# Batches are numbered from 1.
for i_batch, sample_batched in enumerate(tqdm(test_dataloader), 1):

    # Pull the file names and move the model inputs to the GPU as floats.
    filenames = sample_batched['file_name']
    inputs = sample_batched['input_vector'].float().cuda()
    metas = sample_batched['metadata'].float().cuda()
Exemple #9
0
class CNN10Classifier(LightningModule):
    def __init__(self, hparams, fold):
        """Build a Cnn10 classifier for the dataset named in ``hparams.dataset``.

        Args:
            hparams: Namespace-like object; ``dataset``, ``alpha`` and
                ``seed`` are read here.
            fold: Fold identifier, stored for ``prepare_data``.

        Raises:
            ValueError: If ``hparams.dataset`` is not one of 'US8K', 'ESC50'
                or 'SONYCUST'.
        """
        super().__init__()

        # Save hparams for later
        self.hparams = hparams
        self.fold = fold

        # dataset name -> (config attribute holding its folder,
        #                  number of classes, number of tracked best scores).
        # The config attribute is looked up lazily so that only the folder of
        # the selected dataset has to be configured.
        dataset_settings = {
            "US8K": ("path_to_UrbanSound8K", 10, 5),
            "ESC50": ("path_to_ESC50", 50, 5),
            "SONYCUST": ("path_to_SONYCUST", 31, 10),
        }
        if self.hparams.dataset not in dataset_settings:
            # Fail here with a clear message instead of an AttributeError on
            # self.nb_classes a few lines further down.
            raise ValueError("Unknown dataset %r, expected one of %s" %
                             (self.hparams.dataset,
                              sorted(dataset_settings)))
        folder_attr, nb_classes, nb_scores = dataset_settings[
            self.hparams.dataset]
        self.dataset_folder = getattr(config, folder_attr)
        self.nb_classes = nb_classes
        self.best_scores = [0] * nb_scores

        # Settings for the SED models
        model_param = {"classes_num": self.nb_classes}

        self.model = Cnn10(**model_param)

        # Load every pretrained layer matching our layers: only tensors whose
        # name AND shape agree with the new model are copied.
        pretrained_dict = torch.load(config.audiosetCNN10,
                                     map_location="cpu")["model"]
        model_dict = self.model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

        # Element-wise losses; SONYCUST scores its coarse and fine outputs
        # with two separate losses.
        if self.hparams.dataset != "SONYCUST":
            self.loss = BCELoss(reduction="none")
        else:
            self.loss_c = BCELoss(reduction="none")
            self.loss_f = Masked_loss(BCELoss(reduction="none"))
        self.mixup_augmenter = Mixup(self.hparams.alpha, self.hparams.seed)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument("--batch_size", type=int, default=64)
        parser.add_argument("--shuffle", type=bool, default=True)
        parser.add_argument("--init_lr", type=float, default=1e-3)
        parser.add_argument("--alpha", type=float, default=1.0)

        parser.add_argument("--num_mels", type=int, default=64)
        parser.add_argument("--dataset",
                            type=str,
                            default="SONYCUST",
                            choices=["US8K", "ESC50", "SONYCUST"])

        return parser

    def forward(self, x, mixup_lambda=None):
        x = self.model(x, mixup_lambda)
        return x

    def prepare_data(self):
        # Dataset parameters

        # Creating dataset
        if self.hparams.dataset == "US8K":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
            }
            self.dataset = UrbanSound8K_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()

        elif self.hparams.dataset == "ESC50":
            data_param = {
                "dataset_folder": self.dataset_folder,
                "fold": self.fold,
            }
            self.dataset = ESC50_TALNet(**data_param)
            (self.train_dataset,
             self.val_dataset) = self.dataset.train_validation_split()
        elif self.hparams.dataset == "SONYCUST":
            data_param = {
                "sonycust_folder": self.dataset_folder,
                "mode": "both",
                "cleaning_strat": "DCASE",
            }
            self.dataset = SONYCUST_TALNet(**data_param)
            self.train_dataset, self.val_dataset, _ = self.dataset.train_validation_test_split(
            )
        else:
            None

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.hparams.batch_size,
                          shuffle=self.hparams.shuffle,
                          num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.hparams.batch_size,
                          num_workers=4)

    def configure_optimizers(self):
        """Return a Lookahead optimizer (k=5, alpha=0.5) wrapping Ralamb.

        The inner Ralamb optimizer runs over the model parameters with the
        learning rate taken from ``hparams.init_lr``.
        """
        learning_rate = self.hparams.init_lr
        inner_optimizer = Ralamb(self.model.parameters(), lr=learning_rate)
        return Lookahead(inner_optimizer, k=5, alpha=0.5)

    def training_step(self, batch, batch_idx):
        if len(batch["label"]) % 2 == 0:
            lambdas = torch.FloatTensor(
                self.mixup_augmenter.get_lambda(
                    batch_size=len(batch["label"]))).to(self.device)
            is_odd = True
        else:
            lambdas = None
            is_odd = False
        if self.hparams.dataset != "SONYCUST":
            data, target = batch["input_vector"].float(), batch["label"].float(
            )
            output = self.forward(data, lambdas)
            if is_odd:
                target = do_mixup(target, lambdas)
            loss = self.loss(output, target).mean()
        else:
            data, target_c, target_f = (
                batch["input_vector"].float(),
                batch["label"]["coarse"].float(),
                batch["label"]["full_fine"].float(),
            )
            target = torch.cat([target_c, target_f], 1)
            if is_odd:
                target = do_mixup(target, lambdas)
            target_c, target_f = torch.split(target, [8, 29], 1)
            output = self.forward(data, lambdas)
            outputs_c, outputs_f = torch.split(output, [8, 23], 1)
            loss = torch.cat(
                [
                    self.loss_c(outputs_c, target_c).mean(0),
                    self.loss_f(outputs_f, target_f),
                ],
                0,
            ).mean()

        return {"loss": loss, "log": {"1_loss/train_loss": loss}}

    def validation_step(self, batch, batch_idx):
        if self.hparams.dataset != "SONYCUST":
            data, target = batch["input_vector"].float(), batch["label"].float(
            )
            output = self.forward(data)
            # Compute loss of the batch
            loss = self.loss(output, target)
        else:
            data, target_c, target_f = (
                batch["input_vector"].float(),
                batch["label"]["coarse"].float(),
                batch["label"]["full_fine"].float(),
            )
            target = torch.cat([target_c, target_f], 1)
            output = self.forward(data)
            outputs_c, outputs_f = torch.split(output, [8, 23], 1)
            # Compute loss of the batch
            loss = torch.cat(
                [
                    self.loss_c(outputs_c, target_c).mean(0),
                    self.loss_f(outputs_f, target_f),
                ],
                0,
            )

        return {
            "val_loss": loss,
            "output": output,
            "target": target,
        }

    def validation_epoch_end(self, outputs):
        """Aggregate the validation batches, compute metrics and track the best ones.

        For SONYCUST the metrics are computed separately for the coarse and
        fine outputs, and ``self.best_scores`` holds ten values (five coarse,
        then five fine). For the other datasets one set of five metrics is
        computed. In both cases the order is accuracy, micro F1, micro AUPRC,
        macro AUPRC, mean average precision.

        Args:
            outputs: List of the dicts returned by ``validation_step``.

        Returns:
            Dict with a 'progress_bar' dict and a 'log' dict holding the
            epoch number under 'step', the validation loss and every metric.
        """
        val_loss = torch.cat([o["val_loss"] for o in outputs], 0).mean()
        all_outputs = torch.cat([o["output"] for o in outputs],
                                0).cpu().numpy()
        all_targets = torch.cat([o["target"] for o in outputs],
                                0).cpu().numpy()

        if self.hparams.dataset == "SONYCUST":
            # Logic for SONYCUST
            # Targets have 8 coarse + 29 full_fine columns while the outputs
            # have 8 + 23. Six target columns are dropped so both arrays line
            # up. A NumPy mask is used because the arrays are NumPy arrays.
            keep_mask = np.ones(37, dtype=bool)
            keep_mask[[11, 16, 21, 26, 30, 35]] = False

            outputs_split = np.split(all_outputs, [8, 31], 1)
            all_outputs_coarse = outputs_split[0]
            all_outputs_fine = outputs_split[1]

            all_targets = all_targets[:, keep_mask]
            targets_split = np.split(all_targets, [8, 31], 1)
            all_targets_coarse = targets_split[0]
            all_targets_fine = targets_split[1]

            # NOTE(review): compute_macro_auprc is called without the third
            # argument here but with it in the other branch, where the second
            # element of its result is kept - confirm both return what is
            # expected.
            coarse_scores = [
                accuracy(all_targets_coarse, all_outputs_coarse),
                compute_micro_F1(all_targets_coarse, all_outputs_coarse),
                compute_micro_auprc(all_targets_coarse, all_outputs_coarse),
                compute_macro_auprc(all_targets_coarse, all_outputs_coarse),
                mean_average_precision(all_targets_coarse,
                                       all_outputs_coarse),
            ]
            fine_scores = [
                accuracy(all_targets_fine, all_outputs_fine),
                compute_micro_F1(all_targets_fine, all_outputs_fine),
                compute_micro_auprc(all_targets_fine, all_outputs_fine),
                compute_macro_auprc(all_targets_fine, all_outputs_fine),
                mean_average_precision(all_targets_fine, all_outputs_fine),
            ]

            # best_scores[0:5] track the coarse metrics, [5:10] the fine ones.
            for index, score in enumerate(coarse_scores + fine_scores):
                if score > self.best_scores[index]:
                    self.best_scores[index] = score

            (accuracy_c, f1_micro_c, auprc_micro_c, auprc_macro_c,
             map_coarse) = coarse_scores
            (accuracy_f, f1_micro_f, auprc_micro_f, auprc_macro_f,
             map_fine) = fine_scores

            # The accuracy and F1 keys must differ: with identical keys the
            # F1 entry overwrote the accuracy entry in the dict literal.
            log_temp = {
                "2_valid_coarse/1_accuracy@0.5": accuracy_c,
                "2_valid_coarse/1_f1_micro@0.5": f1_micro_c,
                "2_valid_coarse/1_auprc_micro": auprc_micro_c,
                "2_valid_coarse/1_auprc_macro": auprc_macro_c,
                "2_valid_coarse/1_map_coarse": map_coarse,
                "3_valid_fine/1_accuracy@0.5": accuracy_f,
                "3_valid_fine/1_f1_micro@0.5": f1_micro_f,
                "3_valid_fine/1_auprc_micro": auprc_micro_f,
                "3_valid_fine/1_auprc_macro": auprc_macro_f,
                "3_valid_fine/1_map_fine": map_fine,
            }

            tqdm_dict = {
                "val_loss": val_loss,
                "m_auprc_c": auprc_macro_c,
            }

        else:
            # Logic for ESC50 and US8K
            accuracy_score = accuracy(all_targets, all_outputs)
            f1_micro = compute_micro_F1(all_targets, all_outputs)
            auprc_micro = compute_micro_auprc(all_targets, all_outputs)
            _, auprc_macro = compute_macro_auprc(all_targets, all_outputs,
                                                 True)
            map_score = mean_average_precision(all_targets, all_outputs)

            scores = [
                accuracy_score, f1_micro, auprc_micro, auprc_macro, map_score
            ]
            for index, score in enumerate(scores):
                if score > self.best_scores[index]:
                    self.best_scores[index] = score

            log_temp = {
                "2_valid/1_accuracy0.5": accuracy_score,
                "2_valid/1_f1_micro0.5": f1_micro,
                "2_valid/1_auprc_micro": auprc_micro,
                "2_valid/1_auprc_macro": auprc_macro,
                "2_valid/1_map": map_score,
            }

            tqdm_dict = {
                "val_loss": val_loss,
                "acc": accuracy_score,
            }

        log = {
            "step": self.current_epoch,
            "1_loss/val_loss": val_loss,
        }

        log.update(log_temp)

        return {"progress_bar": tqdm_dict, "log": log}