Example #1
    def run_for_epoch(self, epoch, x, y, type):
        self.network.eval()
        # for m in self.network.modules():
        #     if isinstance(m, nn.BatchNorm2d):
        #         m.track_running_stats = False
        self.test_batch_loss, latent_features = [], []

        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(x, y)):
                audio_data = to_tensor(audio_data, device=self.device)
                test_predictions, test_latent = self.network(audio_data)
                test_predictions = test_predictions.squeeze(1)
                latent_features.extend(to_numpy(test_latent.squeeze(1)))
                test_loss = self.loss_function(test_predictions, audio_data)
                self.test_batch_loss.append(to_numpy(test_loss))

        oneclass_predictions = self.oneclass_svm.predict(latent_features)
        masked_predictions = self.mask_preds_for_one_class(
            oneclass_predictions)
        test_metrics = accuracy_fn(
            to_tensor(masked_predictions),
            to_tensor([element for sublist in y for element in sublist]),
            threshold=self.threshold)
        test_metrics = {'test_' + k: v for k, v in test_metrics.items()}
        self.logger.info(f'***** {type} Metrics ***** ')
        self.logger.info(
            f"Loss: {'%.5f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.5f' % test_metrics['test_accuracy']} "
            f"| UAR: {'%.5f' % test_metrics['test_uar']}| F1:{'%.5f' % test_metrics['test_f1']} "
            f"| Precision:{'%.5f' % test_metrics['test_precision']} "
            f"| Recall:{'%.5f' % test_metrics['test_recall']} | AUC:{'%.5f' % test_metrics['test_auc']}"
        )
        wnb.log(test_metrics)

        write_to_npy(filename=self.debug_filename,
                     predictions=masked_predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=test_metrics['test_accuracy'],
                     loss=np.mean(self.test_batch_loss),
                     uar=test_metrics['test_uar'],
                     precision=test_metrics['test_precision'],
                     recall=test_metrics['test_recall'],
                     auc=test_metrics['test_auc'],
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     type=type)
        if epoch + 1 == self.epochs:
            wnb.log({
                "test_cf":
                wnb.sklearn.plot_confusion_matrix(
                    y_true=[label for sublist in y for label in sublist],
                    y_pred=masked_predictions,
                    labels=['Negative', 'Positive'])
            })
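The helper mask_preds_for_one_class used above is not shown on this page. A minimal sketch, assuming the scikit-learn convention that one-class models return +1 for inliers and -1 for outliers, and that outliers are treated as the positive class here:

def mask_preds_for_one_class(predictions):
    # scikit-learn one-class models label inliers +1 and outliers -1;
    # map outliers to 1 (positive) and inliers to 0 (negative).
    return [1 if pred == -1 else 0 for pred in predictions]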
Example #2
    def test(self):
        test_data, test_labels = self.data_reader(
            self.data_read_path + 'test_challenge_data.npy',
            self.data_read_path + 'test_challenge_labels.npy',
            shuffle=False,
            train=False)
        test_predictions = self.network(test_data).squeeze(1)
        test_predictions = nn.Sigmoid()(test_predictions)
        test_accuracy, test_uar = accuracy_fn(test_predictions, test_labels,
                                              self.threshold)
        self.logger.info(f"Accuracy: {test_accuracy} | UAR: {test_uar}")
     self.logger.info(f"Accuracy: {test_accuracy} | UAR: {test_uar}")
# x = np.linspace(mu - 3 * sigma, mu + 3 * sigma, 100)
# plt.plot(x, stats.norm.pdf(x, mu, sigma))
print(train_latent_features[ones_idx].mean(axis=1).shape)
plt.hist(train_latent_features[ones_idx][0])
plt.show()

exit()
# exit()
# model = svm.OneClassSVM(kernel="poly")
# oneclass_svm = IsolationForest(random_state=0)
from sklearn.covariance import EllipticEnvelope

model = EllipticEnvelope()
model.fit(train_latent_features)
oneclass_predictions = model.predict(train_latent_features)
masked_predictions = mask_preds_for_one_class(oneclass_predictions)
train_metrics = accuracy_fn(to_tensor(masked_predictions),
                            to_tensor(train_labels),
                            threshold=threshold)
train_metrics = {'train_' + k: v for k, v in train_metrics.items()}
print('***** Train Metrics *****')
print(
    f"Accuracy: {'%.5f' % train_metrics['train_accuracy']} "
    f"| UAR: {'%.5f' % train_metrics['train_uar']}| F1:{'%.5f' % train_metrics['train_f1']} "
    f"| Precision:{'%.5f' % train_metrics['train_precision']} "
    f"| Recall:{'%.5f' % train_metrics['train_recall']} | AUC:{'%.5f' % train_metrics['train_auc']}"
)
print('Train Confusion matrix - \n' +
      str(confusion_matrix(train_labels, masked_predictions)))

# Test
oneclass_predictions = model.predict(test_latent_features)
masked_predictions = mask_preds_for_one_class(oneclass_predictions)
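The snippet ends before the test metrics are reported. Mirroring the train block above, the remainder would plausibly look like this (a sketch reusing the same helpers):

test_metrics = accuracy_fn(to_tensor(masked_predictions),
                           to_tensor(test_labels),
                           threshold=threshold)
test_metrics = {'test_' + k: v for k, v in test_metrics.items()}
print('***** Test Metrics *****')
print(f"Accuracy: {'%.5f' % test_metrics['test_accuracy']} "
      f"| UAR: {'%.5f' % test_metrics['test_uar']}")
print('Test Confusion matrix - \n' +
      str(confusion_matrix(test_labels, masked_predictions)))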
Example #4
    def train(self):

        train_data, train_labels = self.data_reader(self.data_read_path,
                                                    [self.train_file],
                                                    shuffle=True,
                                                    train=True)
        test_data, test_labels = self.data_reader(self.data_read_path,
                                                  [self.test_file],
                                                  shuffle=False,
                                                  train=False)

        # For the purposes of assigning pos weight on the fly we are initializing the cost function here
        # self.loss_function = nn.BCEWithLogitsLoss(pos_weight=to_tensor(self.pos_weight, device=self.device))
        self.loss_function = nn.MSELoss()

        for epoch in range(1, self.epochs):
            self.oneclass_svm = svm.OneClassSVM(nu=0.1,
                                                kernel="rbf",
                                                gamma=0.1)
            self.network.train()
            self.latent_features, train_predictions = [], []

            for i, (audio_data,
                    label) in enumerate(zip(train_data, train_labels)):
                self.optimiser.zero_grad()

                audio_data = to_tensor(audio_data, device=self.device)
                predictions, latent = self.network(audio_data)
                predictions = predictions.squeeze(1)
                self.latent_features.extend(to_numpy(latent.squeeze(1)))
                loss = self.loss_function(predictions, audio_data)
                loss.backward()
                self.optimiser.step()
                self.batch_loss.append(to_numpy(loss))

            oneclass_predictions = self.oneclass_svm.fit_predict(
                self.latent_features)
            masked_predictions = self.mask_preds_for_one_class(
                oneclass_predictions)
            train_metrics = accuracy_fn(to_tensor(masked_predictions),
                                        to_tensor([
                                            element for sublist in train_labels
                                            for element in sublist
                                        ]),
                                        threshold=self.threshold)
            wnb.log(train_metrics)
            wnb.log({'reconstruction_loss': np.mean(self.batch_loss)})

            # Decay learning rate
            self.scheduler.step(epoch=epoch)
            wnb.log(
                {"LR": self.optimiser.state_dict()['param_groups'][0]['lr']})

            self.logger.info('***** Overall Train Metrics ***** ')
            self.logger.info(
                f"Epoch: {epoch} | Loss: {'%.5f' % np.mean(self.batch_loss)} | Accuracy: {'%.5f' % train_metrics['accuracy']} "
                f"| UAR: {'%.5f' % train_metrics['uar']} | F1:{'%.5f' % train_metrics['f1']} "
                f"| Precision:{'%.5f' % train_metrics['precision']} | Recall:{'%.5f' % train_metrics['recall']} "
                f"| AUC:{'%.5f' % train_metrics['auc']}")

            # test data
            self.run_for_epoch(epoch, test_data, test_labels, type='Test')

            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                save_path_oneclass_svm = self.network_save_path + '/oneclass_svm' + str(
                    epoch) + '.pkl'
                torch.save(self.network.state_dict(), save_path)
                with open(save_path_oneclass_svm, 'wb') as f:
                    pickle.dump(self.oneclass_svm, f)
                self.logger.info(f'Network successfully saved: {save_path}')

            self.oneclass_svm = None  # Clear the model after saving so each epoch fits a fresh one-class SVM
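A checkpoint saved this way can be restored later. A minimal loading sketch, assuming the same path layout, a compatible network instance, and the torch/pickle imports used above:

# Sketch: restore the autoencoder weights and the matching one-class SVM.
network.load_state_dict(torch.load(save_path))
network.eval()
with open(save_path_oneclass_svm, 'rb') as f:
    oneclass_svm = pickle.load(f)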
Example #5
    def infer(self):
        from sklearn import svm
        from sklearn.metrics import confusion_matrix
        import pickle
        self._min, self._max = -80.0, 3.8146973e-06
        train_data, train_labels = self.data_reader(
            self.data_read_path, [self.train_file],
            shuffle=False,
            train=True,
            only_negative_samples=False)

        test_data, test_labels = self.data_reader(self.data_read_path,
                                                  [self.test_file],
                                                  shuffle=False,
                                                  train=False,
                                                  only_negative_samples=False)
        train_latent_features, test_latent_features = [], []
        with torch.no_grad():
            for i, (audio_data,
                    label) in enumerate(zip(train_data, train_labels)):
                audio_data = to_tensor(audio_data, device=self.device)
                _, _, _, train_latent = self.network(audio_data)
                train_latent_features.extend(to_numpy(train_latent.squeeze(1)))
        with open('vae_forced_train_latent.npy', 'wb') as f:
            pickle.dump(train_latent_features, f)

        oneclass_svm = svm.OneClassSVM(nu=0.1, kernel="poly", gamma=0.1)
        oneclass_svm.fit(train_latent_features)
        oneclass_predictions = oneclass_svm.predict(train_latent_features)
        masked_predictions = self.mask_preds_for_one_class(
            oneclass_predictions)
        train_metrics = accuracy_fn(to_tensor(masked_predictions),
                                    to_tensor([
                                        element for sublist in train_labels
                                        for element in sublist
                                    ]),
                                    threshold=self.threshold)
        train_metrics = {'train_' + k: v for k, v in train_metrics.items()}
        self.logger.info('***** Train Metrics *****')
        self.logger.info(
            f"Accuracy: {'%.5f' % train_metrics['train_accuracy']} "
            f"| UAR: {'%.5f' % train_metrics['train_uar']}| F1:{'%.5f' % train_metrics['train_f1']} "
            f"| Precision:{'%.5f' % train_metrics['train_precision']} "
            f"| Recall:{'%.5f' % train_metrics['train_recall']} | AUC:{'%.5f' % train_metrics['train_auc']}"
        )
        self.logger.info('Train Confusion matrix - \n' + str(
            confusion_matrix(
                [element for sublist in train_labels
                 for element in sublist], masked_predictions)))

        # Test
        with torch.no_grad():
            for i, (audio_data,
                    label) in enumerate(zip(test_data, test_labels)):
                audio_data = to_tensor(audio_data, device=self.device)
                _, _, _, test_latent = self.network(audio_data)
                test_latent_features.extend(to_numpy(test_latent.squeeze(1)))
        with open('vae_forced_test_latent.npy', 'wb') as f:
            pickle.dump(test_latent_features, f)

        oneclass_predictions = oneclass_svm.predict(test_latent_features)
        masked_predictions = self.mask_preds_for_one_class(
            oneclass_predictions)
        test_metrics = accuracy_fn(to_tensor(masked_predictions),
                                   to_tensor([
                                       element for sublist in test_labels
                                       for element in sublist
                                   ]),
                                   threshold=self.threshold)
        test_metrics = {'test_' + k: v for k, v in test_metrics.items()}
        self.logger.info('***** Test Metrics *****')
        self.logger.info(
            f"Accuracy: {'%.5f' % test_metrics['test_accuracy']} "
            f"| UAR: {'%.5f' % test_metrics['test_uar']}| F1:{'%.5f' % test_metrics['test_f1']} "
            f"| Precision:{'%.5f' % test_metrics['test_precision']} "
            f"| Recall:{'%.5f' % test_metrics['test_recall']} | AUC:{'%.5f' % test_metrics['test_auc']}"
        )
        self.logger.info('Test Confusion matrix - \n' + str(
            confusion_matrix(
                [element for sublist in test_labels
                 for element in sublist], masked_predictions)))
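The [element for sublist in y for element in sublist] comprehension recurs throughout these examples to flatten per-batch label lists before computing metrics. A small helper (hypothetical, not part of the original code) would make those call sites easier to read:

def flatten(nested):
    # Flatten a list of per-batch lists into one flat list.
    return [element for sublist in nested for element in sublist]

# e.g. accuracy_fn(to_tensor(masked_predictions),
#                  to_tensor(flatten(train_labels)),
#                  threshold=self.threshold)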
Example #6
    def run_for_epoch(self, epoch, x, y, type):
        self.network.eval()
        # for m in self.network.modules():
        #     if isinstance(m, nn.BatchNorm2d):
        #         m.track_running_stats = False
        predictions_dict = {"tp": [], "fp": [], "tn": [], "fn": []}
        logits, predictions = [], []
        self.test_batch_loss, self.test_batch_accuracy, self.test_batch_uar = [], [], []
        self.test_batch_ua, self.test_batch_f1, self.test_batch_precision = [], [], []
        self.test_batch_recall, self.test_batch_auc = [], []
        audio_for_tensorboard_test = None
        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(x, y)):
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                test_predictions = self.network(audio_data).squeeze(1)
                logits.extend(to_numpy(test_predictions))
                test_loss = self.loss_function(test_predictions, label)
                test_predictions = nn.Sigmoid()(test_predictions)
                predictions.append(to_numpy(test_predictions))
                test_batch_metrics = accuracy_fn(test_predictions, label,
                                                 self.threshold)
                self.test_batch_loss.append(to_numpy(test_loss))
                self.test_batch_accuracy.append(
                    to_numpy(test_batch_metrics['accuracy']))
                self.test_batch_uar.append(test_batch_metrics['uar'])
                self.test_batch_f1.append(test_batch_metrics['f1'])
                self.test_batch_precision.append(
                    test_batch_metrics['precision'])
                self.test_batch_recall.append(test_batch_metrics['recall'])
                self.test_batch_auc.append(test_batch_metrics['auc'])
                tp, fp, tn, fn = custom_confusion_matrix(
                    test_predictions, label, threshold=self.threshold)
                predictions_dict['tp'].extend(tp)
                predictions_dict['fp'].extend(fp)
                predictions_dict['tn'].extend(tn)
                predictions_dict['fn'].extend(fn)

        predictions = [
            element for sublist in predictions for element in sublist
        ]
        self.logger.info(f'***** {type} Metrics ***** ')
        self.logger.info(
            f"Loss: {'%.5f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.5f' % np.mean(self.test_batch_accuracy)} "
            f"| UAR: {'%.5f' % np.mean(self.test_batch_uar)}| F1:{'%.5f' % np.mean(self.test_batch_f1)} "
            f"| Precision:{'%.5f' % np.mean(self.test_batch_precision)} "
            f"| Recall:{'%.5f' % np.mean(self.test_batch_recall)} | AUC:{'%.5f' % np.mean(self.test_batch_auc)}"
        )
        epoch_test_batch_metrics = {
            "test_loss": np.mean(self.test_batch_loss),
            "test_accuracy": np.mean(self.test_batch_accuracy),
            "test_uar": np.mean(self.test_batch_uar),
            "test_f1": np.mean(self.test_batch_f1),
            "test_precision": np.mean(self.test_batch_precision),
            "test_recall": np.mean(self.test_batch_recall),
            "test_auc": np.mean(self.test_batch_auc)
        }
        wnb.log(epoch_test_batch_metrics)
        wnb.log({
            "test_cf":
            wnb.sklearn.plot_confusion_matrix(
                y_true=[label for sublist in y for label in sublist],
                y_pred=np.where(np.array(predictions) > self.threshold, 1, 0),
                labels=['Negative', 'Positive'])
        })
        # log_summary(self.writer, epoch, accuracy=np.mean(self.test_batch_accuracy),
        #             loss=np.mean(self.test_batch_loss),
        #             uar=np.mean(self.test_batch_uar), precision=np.mean(self.test_batch_precision),
        #             recall=np.mean(self.test_batch_recall), auc=np.mean(self.test_batch_auc),
        #             lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
        #             type=type)
        # log_conf_matrix(self.writer, epoch, predictions_dict=predictions_dict, type=type)
        #
        # log_learnable_parameter(self.writer, epoch, to_tensor(logits, device=self.device),
        #                         name=f'{type}_logits')
        # log_learnable_parameter(self.writer, epoch, to_tensor(predictions, device=self.device),
        #                         name=f'{type}_predictions')

        write_to_npy(filename=self.debug_filename,
                     predictions=predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=np.mean(self.test_batch_accuracy),
                     loss=np.mean(self.test_batch_loss),
                     uar=np.mean(self.test_batch_uar),
                     precision=np.mean(self.test_batch_precision),
                     recall=np.mean(self.test_batch_recall),
                     auc=np.mean(self.test_batch_auc),
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     predictions_dict=predictions_dict,
                     type=type)
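custom_confusion_matrix is not defined on this page. A hypothetical sketch consistent with how its outputs are used above (four lists of raw scores, one per confusion-matrix bucket):

def custom_confusion_matrix(predictions, labels, threshold):
    # Hypothetical: bucket each activated prediction by its outcome so the
    # raw scores can be logged per tp/fp/tn/fn category.
    tp, fp, tn, fn = [], [], [], []
    for pred, label in zip(to_numpy(predictions), to_numpy(labels)):
        if pred > threshold and label == 1:
            tp.append(pred)
        elif pred > threshold and label == 0:
            fp.append(pred)
        elif pred <= threshold and label == 0:
            tn.append(pred)
        else:
            fn.append(pred)
    return tp, fp, tn, fn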
Example #7
    def train(self):
        train_data, train_labels = self.data_reader(self.data_read_path,
                                                    [self.train_file],
                                                    shuffle=True,
                                                    train=True)
        test_data, test_labels = self.data_reader(self.data_read_path,
                                                  [self.test_file],
                                                  shuffle=False,
                                                  train=False)

        # For the purposes of assigning pos weight on the fly we are initializing the cost function here
        self.loss_function = nn.BCEWithLogitsLoss(
            pos_weight=to_tensor(self.pos_weight, device=self.device))

        total_step = len(train_data)
        for epoch in range(1, self.epochs):
            self.network.train()
            self.batch_loss, self.batch_accuracy, self.batch_uar = [], [], []
            self.batch_f1, self.batch_precision, self.batch_recall = [], [], []
            self.batch_auc, train_predictions, train_logits = [], [], []
            audio_for_tensorboard_train = None
            for i, (audio_data,
                    label) in enumerate(zip(train_data, train_labels)):
                self.optimiser.zero_grad()
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                predictions = self.network(audio_data).squeeze(1)
                train_logits.extend(predictions)
                loss = self.loss_function(predictions, label)
                predictions = nn.Sigmoid()(predictions)
                train_predictions.extend(predictions)
                loss.backward()
                self.optimiser.step()
                self.batch_loss.append(to_numpy(loss))

                batch_metrics = accuracy_fn(predictions, label, self.threshold)
                batch_metrics['loss'] = to_numpy(loss)
                self.batch_accuracy.append(to_numpy(batch_metrics['accuracy']))
                self.batch_uar.append(batch_metrics['uar'])
                self.batch_f1.append(batch_metrics['f1'])
                self.batch_precision.append(batch_metrics['precision'])
                self.batch_recall.append(batch_metrics['recall'])
                self.batch_auc.append(batch_metrics['auc'])

                wnb.log(batch_metrics)

                if i % self.display_interval == 0:
                    self.logger.info(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.5f' % loss} | "
                        f"Accuracy: {'%.5f' % batch_metrics['accuracy']} | UAR: {'%.5f' % batch_metrics['uar']}| "
                        f"F1:{'%.5f' % batch_metrics['f1']} | Precision: {'%.5f' % batch_metrics['precision']} | "
                        f"Recall: {'%.5f' % batch_metrics['recall']} | AUC: {'%.5f' % batch_metrics['auc']}"
                    )

            # log_learnable_parameter(self.writer, epoch, to_tensor(train_logits, device=self.device),
            #                         name='train_logits')
            # log_learnable_parameter(self.writer, epoch, to_tensor(train_predictions, device=self.device),
            #                         name='train_activated')

            # Decay learning rate
            self.scheduler.step(epoch=epoch)
            wnb.log(
                {"LR": self.optimiser.state_dict()['param_groups'][0]['lr']})
            # log_summary(self.writer, epoch, accuracy=np.mean(self.batch_accuracy),
            #             loss=np.mean(self.batch_loss),
            #             uar=np.mean(self.batch_uar), precision=np.mean(self.batch_precision),
            #             recall=np.mean(self.batch_recall),
            #             auc=np.mean(self.batch_auc), lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
            #             type='Train')
            self.logger.info('***** Overall Train Metrics ***** ')
            self.logger.info(
                f"Loss: {'%.5f' % np.mean(self.batch_loss)} | Accuracy: {'%.5f' % np.mean(self.batch_accuracy)} "
                f"| UAR: {'%.5f' % np.mean(self.batch_uar)} | F1:{'%.5f' % np.mean(self.batch_f1)} "
                f"| Precision:{'%.5f' % np.mean(self.batch_precision)} | Recall:{'%.5f' % np.mean(self.batch_recall)} "
                f"| AUC:{'%.5f' % np.mean(self.batch_auc)}")

            # test data
            self.run_for_epoch(epoch, test_data, test_labels, type='Test')

            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                torch.save(self.network.state_dict(), save_path)
                self.logger.info(f'Network successfully saved: {save_path}')
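self.pos_weight is set elsewhere in the class. A common choice for imbalanced binary data, and only an assumption here, is the ratio of negative to positive samples, which makes BCEWithLogitsLoss penalise missed positives more heavily:

# Sketch: derive pos_weight from the (flattened) training labels.
labels_flat = np.array([element for sublist in train_labels for element in sublist])
pos_weight = (labels_flat == 0).sum() / max((labels_flat == 1).sum(), 1)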
Example #8
    def infer(self):
        # from sklearn import svm
        # from sklearn.ensemble import IsolationForest
        # from sklearn.metrics import confusion_matrix
        # import pickle as pk
        # train_labels, test_labels = pk.load(open(
        #         '/Users/badgod/badgod_documents/Datasets/covid19/processed_data/coswara_train_data_fbank_cough-shallow_labels.pkl',
        #         'rb')), pk.load(open(
        #         '/Users/badgod/badgod_documents/Datasets/covid19/processed_data/coswara_test_data_fbank_cough-shallow_labels.pkl',
        #         'rb'))
        # train_latent_features, test_latent_features = pk.load(
        #         open('/Users/badgod/badgod_documents/Datasets/covid19/processed_data/forced_train_latent.npy',
        #              'rb')), pk.load(
        #         open('/Users/badgod/badgod_documents/Datasets/covid19/processed_data/forced_test_latent.npy', 'rb'))
        # # for x, y in zip(train_latent_features, train_labels):
        # #     if y == 0:
        # #         print('Mean: ', np.mean(x), ' Std: ', np.std(x), ' | Label: ', y)
        # # for x, y in zip(train_latent_features, train_labels):
        # #     if y == 1:
        # #         print('Mean: ', np.mean(x), ' Std: ', np.std(x), ' | Label: ', y)
        # #
        # # exit()
        # self.logger.info(
        #         'Total train data len: ' + str(len(train_labels)) + ' | Positive samples: ' + str(sum(train_labels)))
        # self.logger.info(
        #         'Total test data len: ' + str(len(test_labels)) + ' | Positive samples: ' + str(sum(test_labels)))
        # # oneclass_svm = svm.OneClassSVM(kernel="rbf")
        # oneclass_svm = IsolationForest(random_state=0)
        # oneclass_svm.fit(train_latent_features)
        # oneclass_predictions = oneclass_svm.predict(train_latent_features)
        # masked_predictions = self.mask_preds_for_one_class(oneclass_predictions)
        # train_metrics = accuracy_fn(to_tensor(masked_predictions), to_tensor(train_labels), threshold=self.threshold)
        # train_metrics = {'train_' + k: v for k, v in train_metrics.items()}
        # self.logger.info(f'***** Train Metrics ***** ')
        # self.logger.info(
        #         f"Accuracy: {'%.5f' % train_metrics['train_accuracy']} "
        #         f"| UAR: {'%.5f' % train_metrics['train_uar']}| F1:{'%.5f' % train_metrics['train_f1']} "
        #         f"| Precision:{'%.5f' % train_metrics['train_precision']} "
        #         f"| Recall:{'%.5f' % train_metrics['train_recall']} | AUC:{'%.5f' % train_metrics['train_auc']}")
        # self.logger.info('Train Confusion matrix - \n' + str(confusion_matrix(train_labels, masked_predictions)))
        # # Test
        # oneclass_predictions = oneclass_svm.predict(test_latent_features)
        # masked_predictions = self.mask_preds_for_one_class(oneclass_predictions)
        # test_metrics = accuracy_fn(to_tensor(masked_predictions), to_tensor(test_labels), threshold=self.threshold)
        # test_metrics = {'test_' + k: v for k, v in test_metrics.items()}
        # self.logger.info(f'***** Test Metrics ***** ')
        # self.logger.info(
        #         f"Accuracy: {'%.5f' % test_metrics['test_accuracy']} "
        #         f"| UAR: {'%.5f' % test_metrics['test_uar']}| F1:{'%.5f' % test_metrics['test_f1']} "
        #         f"| Precision:{'%.5f' % test_metrics['test_precision']} "
        #         f"| Recall:{'%.5f' % test_metrics['test_recall']} | AUC:{'%.5f' % test_metrics['test_auc']}")
        # self.logger.info('Test Confusion matrix - \n' + str(confusion_matrix(test_labels, masked_predictions)))

        from sklearn import svm
        from sklearn.metrics import confusion_matrix
        import pickle
        self._min, self._max = -80.0, 3.8146973e-06
        train_data, train_labels = self.data_reader(self.data_read_path, [self.train_file],
                                                    shuffle=False,
                                                    train=True, only_negative_samples=False)

        test_data, test_labels = self.data_reader(self.data_read_path, [self.test_file],
                                                  shuffle=False,
                                                  train=False, only_negative_samples=False)
        train_latent_features, test_latent_features = [], []
        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(train_data, train_labels)):
                audio_data = to_tensor(audio_data, device=self.device)
                train_predictions, train_latent = self.network(audio_data)
                train_latent_features.extend(to_numpy(train_latent.squeeze(1)))
        with open('ae_contrastive_train_latent.npy', 'wb') as f:
            pickle.dump(train_latent_features, f)

        oneclass_svm = svm.OneClassSVM(nu=0.1, kernel="poly", gamma=0.1)
        oneclass_svm.fit(train_latent_features)
        oneclass_predictions = oneclass_svm.predict(train_latent_features)
        masked_predictions = self.mask_preds_for_one_class(oneclass_predictions)
        train_metrics = accuracy_fn(to_tensor(masked_predictions),
                                    to_tensor([element for sublist in train_labels for element in sublist]),
                                    threshold=self.threshold)
        train_metrics = {'train_' + k: v for k, v in train_metrics.items()}
        self.logger.info('***** Train Metrics *****')
        self.logger.info(
                f"Accuracy: {'%.5f' % train_metrics['train_accuracy']} "
                f"| UAR: {'%.5f' % train_metrics['train_uar']}| F1:{'%.5f' % train_metrics['train_f1']} "
                f"| Precision:{'%.5f' % train_metrics['train_precision']} "
                f"| Recall:{'%.5f' % train_metrics['train_recall']} | AUC:{'%.5f' % train_metrics['train_auc']}")
        self.logger.info('Train Confusion matrix - \n' + str(
                confusion_matrix([element for sublist in train_labels for element in sublist], masked_predictions)))

        # Test
        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(test_data, test_labels)):
                audio_data = to_tensor(audio_data, device=self.device)
                test_predictions, test_latent = self.network(audio_data)
                test_latent_features.extend(to_numpy(test_latent.squeeze(1)))
        with open('ae_contrastive_test_latent.npy', 'wb') as f:
            pickle.dump(test_latent_features, f)

        oneclass_predictions = oneclass_svm.predict(test_latent_features)
        masked_predictions = self.mask_preds_for_one_class(oneclass_predictions)
        test_metrics = accuracy_fn(to_tensor(masked_predictions),
                                   to_tensor([element for sublist in test_labels for element in sublist]),
                                   threshold=self.threshold)
        test_metrics = {'test_' + k: v for k, v in test_metrics.items()}
        self.logger.info('***** Test Metrics *****')
        self.logger.info(
                f"Accuracy: {'%.5f' % test_metrics['test_accuracy']} "
                f"| UAR: {'%.5f' % test_metrics['test_uar']}| F1:{'%.5f' % test_metrics['test_f1']} "
                f"| Precision:{'%.5f' % test_metrics['test_precision']} "
                f"| Recall:{'%.5f' % test_metrics['test_recall']} | AUC:{'%.5f' % test_metrics['test_auc']}")
        self.logger.info('Test Confusion matrix - \n' + str(
                confusion_matrix([element for sublist in test_labels for element in sublist], masked_predictions)))

        train_latent_features = np.array(train_latent_features)
        test_latent_features = np.array(test_latent_features)
        train_labels_flat = [element for sublist in train_labels for element in sublist]
        ones_idx = [i for i, x in enumerate(train_labels_flat) if x == 1]
        zeros_idx = [i for i, x in enumerate(train_labels_flat) if x == 0]
        print(train_latent_features[ones_idx].mean(), train_latent_features[ones_idx].std())
        print(train_latent_features[zeros_idx].mean(), train_latent_features[zeros_idx].std())

        test_labels_flat = [element for sublist in test_labels for element in sublist]
        ones_idx = [i for i, x in enumerate(test_labels_flat) if x == 1]
        zeros_idx = [i for i, x in enumerate(test_labels_flat) if x == 0]
        print(test_latent_features[ones_idx].mean(), test_latent_features[ones_idx].std())
        print(test_latent_features[zeros_idx].mean(), test_latent_features[zeros_idx].std())
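The per-class means and standard deviations printed above can also be inspected visually. A minimal sketch, assuming matplotlib as in Example #2:

import matplotlib.pyplot as plt

# Compare the latent-feature distributions of the two classes.
plt.hist(test_latent_features[zeros_idx].mean(axis=1), alpha=0.5, label='Negative')
plt.hist(test_latent_features[ones_idx].mean(axis=1), alpha=0.5, label='Positive')
plt.legend()
plt.show()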