Example #1
    def train(self):
        train_inp_file, train_out_file, train_jitter_file = 'train_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy', 'train_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy', 'train_challenge_with_shimmer_jitter.npy'
        dev_inp_file, dev_out_file, dev_jitter_file = 'dev_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy', 'dev_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy', 'dev_challenge_with_shimmer_jitter.npy'
        test_inp_file, test_out_file, test_jitter_file = 'test_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy', 'test_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy', 'test_challenge_with_shimmer_jitter.npy'

        self.logger.info(
            f'Reading train file {train_inp_file, train_out_file}')
        self.data_loaders['train'] = self.data_reader(
            self.data_read_path + train_inp_file,
            self.data_read_path + train_out_file,
            self.data_read_path + train_jitter_file,
            shuffle=True,
            train=True,
            type='train')
        # a,b = next(iter(data_loaders['train']))
        self.logger.info(f'Reading dev file {dev_inp_file, dev_out_file}')
        self.data_loaders['dev'] = self.data_reader(
            self.data_read_path + dev_inp_file,
            self.data_read_path + dev_out_file,
            self.data_read_path + dev_jitter_file,
            shuffle=False,
            train=False,
            type='dev')
        self.logger.info(f'Reading test file {test_inp_file, test_out_file}')
        self.data_loaders['test'] = self.data_reader(
            self.data_read_path + test_inp_file,
            self.data_read_path + test_out_file,
            self.data_read_path + test_jitter_file,
            shuffle=False,
            train=False,
            type='test')

        model_ft = models.resnet18(pretrained=True)
        for param in model_ft.parameters():
            param.requires_grad = False
        num_ftrs = model_ft.fc.in_features
        # Replace the final fully-connected layer so the output dimension
        # matches the number of classes for this task.
        model_ft.fc = nn.Linear(num_ftrs, self.num_classes)
        model_ft = model_ft.to(self.device)
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=to_tensor(self.pos_weight, device=self.device))

        # All parameters are passed to the optimiser, but only the replaced
        # fc layer has requires_grad=True, so it is the only layer updated
        optimizer_ft = optim.Adam(model_ft.parameters(), lr=self.learning_rate)

        # Decay LR by a factor of 0.1 every 7 epochs
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft,
                                               step_size=7,
                                               gamma=0.1)

        self.train_model(model_ft,
                         criterion,
                         optimizer_ft,
                         exp_lr_scheduler,
                         num_epochs=self.epochs)
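
These examples lean on small conversion helpers, to_tensor and to_numpy, whose definitions are not shown. A minimal sketch of what they presumably do, assuming plain list/array in, tensor/array out semantics:

import numpy as np
import torch

def to_tensor(x, device='cpu'):
    # Wrap numpy arrays / lists / scalars as a tensor on the target device.
    if torch.is_tensor(x):
        return x.to(device)
    return torch.as_tensor(np.asarray(x), device=device)

def to_numpy(x):
    # Detach from the autograd graph and move to CPU before converting.
    return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)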
Example #2
    def train(self):

        # The data reader computes normalisation statistics from the training
        # split, so call it on the train data before dev and test
        train_data, train_labels = self.data_reader(
            self.data_read_path +
            'train_challenge_with_d1_raw_16k_data_4sec.npy',
            self.data_read_path +
            'train_challenge_with_d1_raw_16k_labels_4sec.npy',
            shuffle=True,
            train=True)
        dev_data, dev_labels = self.data_reader(
            self.data_read_path +
            'dev_challenge_with_d1_raw_16k_data_4sec.npy',
            self.data_read_path +
            'dev_challenge_with_d1_raw_16k_labels_4sec.npy',
            shuffle=False,
            train=False)
        test_data, test_labels = self.data_reader(
            self.data_read_path +
            'test_challenge_with_d1_raw_16k_data_4sec.npy',
            self.data_read_path +
            'test_challenge_with_d1_raw_16k_labels_4sec.npy',
            shuffle=False,
            train=False)

        # The loss function is initialised here so that pos_weight can be
        # assigned on the fly
        self.loss_function = nn.BCEWithLogitsLoss(
            pos_weight=to_tensor(self.pos_weight, device=self.device))

        def log_print(*args):
            # Echo each message to stdout and to the run's log file.
            print(*args)
            print(*args, file=self.log_file)

        total_step = len(train_data)
        for epoch in range(1, self.epochs + 1):  # 1-based; include the final epoch
            self.network.train()
            self.batch_loss, self.batch_accuracy, self.batch_uar, self.batch_f1, \
                self.batch_precision, self.batch_recall, \
                audio_for_tensorboard_train = [], [], [], [], [], [], None
            for i, (audio_data,
                    label) in enumerate(zip(train_data, train_labels)):
                self.optimiser.zero_grad()
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                # if i == 0:
                #     self.writer.add_graph(self.network, audio_data)
                predictions = self.network(audio_data).squeeze(1)
                loss = self.loss_function(predictions, label)
                predictions = nn.Sigmoid()(predictions)
                loss.backward()
                self.optimiser.step()
                accuracy, uar, precision, recall, f1 = accuracy_fn(
                    predictions, label, self.threshold)
                self.batch_loss.append(to_numpy(loss))
                self.batch_accuracy.append(to_numpy(accuracy))
                self.batch_uar.append(uar)
                self.batch_f1.append(f1)
                self.batch_precision.append(precision)
                self.batch_recall.append(recall)

                if i % self.display_interval == 0:
                    preds_np = predictions.detach().cpu().numpy()
                    print("*" * 64)
                    log_print("predictions mean", np.mean(preds_np))
                    log_print("predictions sum", np.sum(preds_np))
                    log_print("predictions range", np.min(preds_np),
                              np.max(preds_np))
                    log_print("predictions hist", np.histogram(preds_np))
                    log_print("predictions variance", np.var(preds_np))
                    print("*" * 64)
                    log_print(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.3f' % loss} | Accuracy: {'%.3f' % accuracy} | UAR: {'%.3f' % uar} | F1: {'%.3f' % f1} | Precision: {'%.3f' % precision} | Recall: {'%.3f' % recall}"
                    )

            # Decay learning rate
            # self.scheduler.step(epoch=epoch)
            log_summary(
                self.writer,
                epoch,
                accuracy=np.mean(self.batch_accuracy),
                loss=np.mean(self.batch_loss),
                uar=np.mean(self.batch_uar),
                lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                type='Train')
            log_print('***** Overall Train Metrics ***** ')
            log_print(
                f"Loss: {'%.3f' % np.mean(self.batch_loss)} | Accuracy: {'%.3f' % np.mean(self.batch_accuracy)} | UAR: {'%.3f' % np.mean(self.batch_uar)} | F1: {'%.3f' % np.mean(self.batch_f1)} | Precision: {'%.3f' % np.mean(self.batch_precision)} | Recall: {'%.3f' % np.mean(self.batch_recall)}"
            )
            log_print('Learning rate ',
                      self.optimiser.state_dict()['param_groups'][0]['lr'])
            # dev data
            self.run_for_epoch(epoch, dev_data, dev_labels, type='Dev')

            # test data
            self.run_for_epoch(epoch, test_data, test_labels, type='Test')

            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                torch.save(self.network.state_dict(), save_path)
                print('Network successfully saved: ' + save_path)
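
accuracy_fn is called throughout with sigmoid outputs, the labels, and a probability threshold, and returns accuracy, UAR, precision, recall and F1. Its implementation is not shown; a plausible sketch built on scikit-learn (an assumption, not the repo's actual code):

import numpy as np
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score)

def accuracy_fn(predictions, labels, threshold):
    # predictions are sigmoid outputs in [0, 1]; threshold them into classes.
    y_pred = (predictions.detach().cpu().numpy() > threshold).astype(int)
    y_true = labels.detach().cpu().numpy().astype(int)
    accuracy = accuracy_score(y_true, y_pred)
    # UAR = unweighted average recall, i.e. macro-averaged recall.
    uar = recall_score(y_true, y_pred, average='macro')
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    return accuracy, uar, precision, recall, f1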
Example #3
    def run_for_epoch(self, epoch, x, y, type):
        self.network.eval()
        # Disable the running statistics so BatchNorm uses per-batch
        # statistics during evaluation instead of the training-time
        # running estimates.
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = False
        predictions_dict = {"tp": [], "fp": [], "tn": [], "fn": []}
        overall_predictions = []

        def log_print(*args):
            # Echo each message to stdout and to the run's log file.
            print(*args)
            print(*args, file=self.log_file)

        self.test_batch_loss, self.test_batch_accuracy, self.test_batch_uar, \
            self.test_batch_ua, self.test_batch_f1, self.test_batch_precision, \
            self.test_batch_recall, audio_for_tensorboard_test = \
            [], [], [], [], [], [], [], None
        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(x, y)):
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                test_predictions = self.network(audio_data).squeeze(1)
                test_loss = self.loss_function(test_predictions, label)
                test_predictions = nn.Sigmoid()(test_predictions)
                overall_predictions.extend(to_numpy(test_predictions))
                test_accuracy, test_uar, test_precision, test_recall, test_f1 = accuracy_fn(
                    test_predictions, label, self.threshold)
                self.test_batch_loss.append(to_numpy(test_loss))
                self.test_batch_accuracy.append(to_numpy(test_accuracy))
                self.test_batch_uar.append(test_uar)
                self.test_batch_f1.append(test_f1)
                self.test_batch_precision.append(test_precision)
                self.test_batch_recall.append(test_recall)

                tp, fp, tn, fn = custom_confusion_matrix(
                    test_predictions, label, threshold=self.threshold)
                predictions_dict['tp'].extend(tp)
                predictions_dict['fp'].extend(fp)
                predictions_dict['tn'].extend(tn)
                predictions_dict['fn'].extend(fn)

        log_print(f'***** {type} Metrics ***** ')
        log_print(
            f"Loss: {'%.3f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.3f' % np.mean(self.test_batch_accuracy)} | UAR: {'%.3f' % np.mean(self.test_batch_uar)} | F1: {'%.3f' % np.mean(self.test_batch_f1)} | Precision: {'%.3f' % np.mean(self.test_batch_precision)} | Recall: {'%.3f' % np.mean(self.test_batch_recall)}"
        )

        print("*" * 64)
        print(np.array(overall_predictions).shape)
        log_print(f"{type} predictions mean", np.mean(overall_predictions))
        log_print(f"{type} predictions sum", np.sum(overall_predictions))
        log_print(f"{type} predictions range", np.min(overall_predictions),
                  np.max(overall_predictions))
        log_print(f"{type} predictions hist",
                  np.histogram(overall_predictions))
        log_print(f"{type} predictions variance",
                  np.var(overall_predictions))
        print("*" * 64)

        log_summary(self.writer,
                    epoch,
                    accuracy=np.mean(self.test_batch_accuracy),
                    loss=np.mean(self.test_batch_loss),
                    uar=np.mean(self.test_batch_uar),
                    lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                    type=type)
        log_conf_matrix(self.writer,
                        epoch,
                        predictions_dict=predictions_dict,
                        type=type)

        y = [element for sublist in y for element in sublist]
        write_to_npy(filename=self.debug_filename,
                     predictions=overall_predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=np.mean(self.test_batch_accuracy),
                     loss=np.mean(self.test_batch_loss),
                     uar=np.mean(self.test_batch_uar),
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     predictions_dict=predictions_dict,
                     type=type)
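
custom_confusion_matrix is only visible at its call sites. Judging from how tp/fp/tn/fn are extended into lists and later logged, it appears to return the sigmoid scores bucketed by confusion-matrix outcome; a sketch under that assumption:

def custom_confusion_matrix(predictions, labels, threshold):
    # Bucket each sigmoid score into TP/FP/TN/FN lists so the score
    # distribution of every outcome can be logged per epoch.
    preds = predictions.detach().cpu().numpy()
    truth = labels.detach().cpu().numpy().astype(int)
    tp = [p for p, t in zip(preds, truth) if p > threshold and t == 1]
    fp = [p for p, t in zip(preds, truth) if p > threshold and t == 0]
    tn = [p for p, t in zip(preds, truth) if p <= threshold and t == 0]
    fn = [p for p, t in zip(preds, truth) if p <= threshold and t == 1]
    return tp, fp, tn, fn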
Example #4
    def train_model(self,
                    model,
                    criterion,
                    optimizer,
                    scheduler,
                    num_epochs=25):
        def calc_uar(y_true, y_pred):
            uar = recall_score(to_numpy(y_true),
                               to_numpy(y_pred),
                               average='macro')
            return uar

        since = time.time()

        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # Each epoch runs a training pass followed by dev and test evaluation
            for phase in ['train', 'dev', 'test']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()  # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                total_preds, total_labels = [], []
                # Iterate over data.
                for inputs, labels in self.data_loaders[phase]:
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)
                    total_labels.extend(to_numpy(labels))

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs).squeeze(1)
                        loss = criterion(outputs, labels)

                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    # `outputs` are raw logits (the criterion is
                    # BCEWithLogitsLoss), so apply the sigmoid before
                    # comparing against the probability-scale threshold.
                    probs = torch.sigmoid(outputs)
                    binary_preds = torch.where(
                        probs > to_tensor(self.threshold), to_tensor(1),
                        to_tensor(0))
                    total_preds.extend(to_numpy(binary_preds))
                    running_corrects += torch.sum(binary_preds == labels)
                if phase == 'train':
                    scheduler.step()

                epoch_loss = running_loss / self.dataset_sizes[phase]
                epoch_acc = running_corrects.double() / self.dataset_sizes[phase]
                # recall_score expects (y_true, y_pred) in that order
                epoch_uar = calc_uar(total_labels, total_preds)

                print('{} Loss: {:.4f} Acc: {:.4f} UAR: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_uar))

                # deep copy the model
                if phase == 'dev' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:.4f}'.format(best_acc))

        # load best model weights
        model.load_state_dict(best_model_wts)
        return model
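
Examples #1, #2 and #5 all weight BCEWithLogitsLoss with self.pos_weight, but its computation is never shown. A common choice, assumed here purely for illustration, is the negative-to-positive count ratio of the training labels:

import numpy as np

def compute_pos_weight(train_labels):
    # pos_weight > 1 up-weights the positive class in BCEWithLogitsLoss;
    # the usual heuristic is num_negatives / num_positives.
    labels = np.concatenate([np.ravel(batch) for batch in train_labels])
    num_pos = float((labels == 1).sum())
    num_neg = float((labels == 0).sum())
    return num_neg / max(num_pos, 1.0)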
Example #5
    def train(self):

        # The data reader computes normalisation statistics from the training
        # split, so call it on the train data before dev and test
        train_inp_file, train_out_file, train_jitter_file = 'train_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy', 'train_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy', 'train_challenge_with_shimmer_jitter.npy'
        dev_inp_file, dev_out_file, dev_jitter_file = 'dev_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy', 'dev_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy', 'dev_challenge_with_shimmer_jitter.npy'
        test_inp_file, test_out_file, test_jitter_file = 'test_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy', 'test_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy', 'test_challenge_with_shimmer_jitter.npy'

        self.logger.info(
            f'Reading train file {train_inp_file, train_out_file}')
        train_data, train_labels, train_jitter_data = self.data_reader(
            self.data_read_path + train_inp_file,
            self.data_read_path + train_out_file,
            self.data_read_path + train_jitter_file,
            shuffle=True,
            train=True)
        self.logger.info(f'Reading dev file {dev_inp_file, dev_out_file}')
        dev_data, dev_labels, dev_jitter_data = self.data_reader(
            self.data_read_path + dev_inp_file,
            self.data_read_path + dev_out_file,
            self.data_read_path + dev_jitter_file,
            shuffle=False,
            train=False)
        self.logger.info(f'Reading test file {test_inp_file, test_out_file}')
        test_data, test_labels, test_jitter_data = self.data_reader(
            self.data_read_path + test_inp_file,
            self.data_read_path + test_out_file,
            self.data_read_path + test_jitter_file,
            shuffle=False,
            train=False)

        # The loss function is initialised here so that pos_weight can be
        # assigned on the fly
        self.loss_function = nn.BCEWithLogitsLoss(
            pos_weight=to_tensor(self.pos_weight, device=self.device))

        total_step = len(train_data)
        for epoch in range(1, self.epochs + 1):  # 1-based; include the final epoch
            log_learnable_parameter(
                self.writer,
                epoch - 1,
                network_params=self.network.named_parameters())
            self.network.train()
            self.batch_loss, self.batch_accuracy, self.batch_uar, self.batch_f1, self.batch_precision, \
            self.batch_recall, train_predictions, train_logits, audio_for_tensorboard_train = [], [], [], [], [], [], [], [], None
            for i, (audio_data, label, jitter_shimmer_data) in enumerate(
                    zip(train_data, train_labels, train_jitter_data)):
                self.optimiser.zero_grad()
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                jitter_shimmer_data = to_tensor(jitter_shimmer_data,
                                                device=self.device)
                if i == 0:
                    self.writer.add_graph(self.network,
                                          (audio_data, jitter_shimmer_data))
                predictions = self.network(audio_data,
                                           jitter_shimmer_data).squeeze(1)
                # Detach before accumulating so the autograd graph is not
                # retained across the whole epoch
                train_logits.extend(to_numpy(predictions))
                loss = self.loss_function(predictions, label)
                predictions = nn.Sigmoid()(predictions)
                train_predictions.extend(to_numpy(predictions))
                loss.backward()
                self.optimiser.step()
                accuracy, uar, precision, recall, f1 = accuracy_fn(
                    predictions, label, self.threshold)
                self.batch_loss.append(to_numpy(loss))
                self.batch_accuracy.append(to_numpy(accuracy))
                self.batch_uar.append(uar)
                self.batch_f1.append(f1)
                self.batch_precision.append(precision)
                self.batch_recall.append(recall)

                if i % self.display_interval == 0:
                    self.logger.info(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.3f' % loss} | Accuracy: {'%.3f' % accuracy} | UAR: {'%.3f' % uar}| F1:{'%.3f' % f1} | Precision: {'%.3f' % precision} | Recall: {'%.3f' % recall}"
                    )

            log_learnable_parameter(self.writer,
                                    epoch,
                                    to_tensor(train_logits,
                                              device=self.device),
                                    name='train_logits')
            log_learnable_parameter(self.writer,
                                    epoch,
                                    to_tensor(train_predictions,
                                              device=self.device),
                                    name='train_activated')

            # Decay learning rate
            self.scheduler.step(epoch=epoch)
            log_summary(
                self.writer,
                epoch,
                accuracy=np.mean(self.batch_accuracy),
                loss=np.mean(self.batch_loss),
                uar=np.mean(self.batch_uar),
                lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                type='Train')
            self.logger.info('***** Overall Train Metrics ***** ')
            self.logger.info(
                f"Loss: {'%.3f' % np.mean(self.batch_loss)} | Accuracy: {'%.3f' % np.mean(self.batch_accuracy)} | UAR: {'%.3f' % np.mean(self.batch_uar)} | F1:{'%.3f' % np.mean(self.batch_f1)} | Precision:{'%.3f' % np.mean(self.batch_precision)} | Recall:{'%.3f' % np.mean(self.batch_recall)}"
            )
            self.logger.info(
                f"Learning rate {self.optimiser.state_dict()['param_groups'][0]['lr']}"
            )

            # dev data
            self.run_for_epoch(epoch,
                               dev_data,
                               dev_labels,
                               dev_jitter_data,
                               type='Dev')

            # test data
            self.run_for_epoch(epoch,
                               test_data,
                               test_labels,
                               test_jitter_data,
                               type='Test')

            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                torch.save(self.network.state_dict(), save_path)
                self.logger.info(f'Network successfully saved: {save_path}')
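
log_summary and log_learnable_parameter are TensorBoard helpers whose bodies are not included. Sketches with signatures inferred from the call sites above (the scalar tags and the histogram grouping are assumptions):

def log_summary(writer, epoch, accuracy, loss, uar, lr, type):
    # One scalar per metric, grouped under the split name in TensorBoard.
    writer.add_scalar(f'{type}/Loss', loss, epoch)
    writer.add_scalar(f'{type}/Accuracy', accuracy, epoch)
    writer.add_scalar(f'{type}/UAR', uar, epoch)
    writer.add_scalar(f'{type}/LearningRate', lr, epoch)

def log_learnable_parameter(writer, epoch, tensor=None, name=None,
                            network_params=None):
    # Histogram either a single tensor (logits / activations) or every
    # named parameter of the network.
    if network_params is not None:
        for param_name, param in network_params:
            writer.add_histogram(param_name, param, epoch)
    elif tensor is not None:
        writer.add_histogram(name, tensor, epoch)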
Example #6
    def run_for_epoch(self, epoch, x, y, jitterx, type):
        # Note: eval() is left commented out here, so BatchNorm and dropout
        # layers stay in training mode during these evaluation passes.
        # self.network.eval()
        # for m in self.network.modules():
        #     if isinstance(m, nn.BatchNorm2d):
        #         m.track_running_stats = False
        predictions_dict = {"tp": [], "fp": [], "tn": [], "fn": []}
        logits, predictions = [], []
        self.test_batch_loss, self.test_batch_accuracy, self.test_batch_uar, \
            self.test_batch_ua, self.test_batch_f1, self.test_batch_precision, \
            self.test_batch_recall, audio_for_tensorboard_test = \
            [], [], [], [], [], [], [], None
        with torch.no_grad():
            for i, (audio_data, label,
                    jitter_shimmer_data) in enumerate(zip(x, y, jitterx)):
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                jitter_shimmer_data = to_tensor(jitter_shimmer_data,
                                                device=self.device)
                test_predictions = self.network(audio_data,
                                                jitter_shimmer_data).squeeze(1)
                logits.extend(to_numpy(test_predictions))
                test_loss = self.loss_function(test_predictions, label)
                test_predictions = nn.Sigmoid()(test_predictions)
                predictions.append(to_numpy(test_predictions))
                test_accuracy, test_uar, test_precision, test_recall, test_f1 = accuracy_fn(
                    test_predictions, label, self.threshold)
                self.test_batch_loss.append(to_numpy(test_loss))
                self.test_batch_accuracy.append(to_numpy(test_accuracy))
                self.test_batch_uar.append(test_uar)
                self.test_batch_f1.append(test_f1)
                self.test_batch_precision.append(test_precision)
                self.test_batch_recall.append(test_recall)

                tp, fp, tn, fn = custom_confusion_matrix(
                    test_predictions, label, threshold=self.threshold)
                predictions_dict['tp'].extend(tp)
                predictions_dict['fp'].extend(fp)
                predictions_dict['tn'].extend(tn)
                predictions_dict['fn'].extend(fn)

        predictions = [
            element for sublist in predictions for element in sublist
        ]
        self.logger.info(f'***** {type} Metrics ***** ')
        self.logger.info(
            f"Loss: {'%.3f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.3f' % np.mean(self.test_batch_accuracy)} | UAR: {'%.3f' % np.mean(self.test_batch_uar)}| F1:{'%.3f' % np.mean(self.test_batch_f1)} | Precision:{'%.3f' % np.mean(self.test_batch_precision)} | Recall:{'%.3f' % np.mean(self.test_batch_recall)}"
        )

        log_summary(self.writer,
                    epoch,
                    accuracy=np.mean(self.test_batch_accuracy),
                    loss=np.mean(self.test_batch_loss),
                    uar=np.mean(self.test_batch_uar),
                    lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                    type=type)
        log_conf_matrix(self.writer,
                        epoch,
                        predictions_dict=predictions_dict,
                        type=type)

        log_learnable_parameter(self.writer,
                                epoch,
                                to_tensor(logits, device=self.device),
                                name=f'{type}_logits')
        log_learnable_parameter(self.writer,
                                epoch,
                                to_tensor(predictions, device=self.device),
                                name=f'{type}_predictions')

        # Flatten the nested label batches to match the flattened predictions
        y = [element for sublist in y for element in sublist]
        write_to_npy(filename=self.debug_filename,
                     predictions=predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=np.mean(self.test_batch_accuracy),
                     loss=np.mean(self.test_batch_loss),
                     uar=np.mean(self.test_batch_uar),
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     predictions_dict=predictions_dict,
                     type=type)
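
write_to_npy is likewise only visible at its call sites. A minimal sketch under the assumption that it appends one record per epoch to a pickled object array for offline debugging:

import os
import numpy as np

def write_to_npy(filename, **kwargs):
    # Append this epoch's record (predictions, labels, metrics, ...) so the
    # whole run can be inspected offline.
    records = (list(np.load(filename, allow_pickle=True))
               if os.path.exists(filename) else [])
    records.append(kwargs)
    np.save(filename, np.array(records, dtype=object), allow_pickle=True)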