Example #1
    def run_for_epoch(self, epoch, x, y, type):
        self.network.eval()
        predictions_dict = {"tp": [], "fp": [], "tn": [], "fn": []}
        predictions = []
        self.test_batch_loss, self.test_batch_accuracy, self.test_batch_uar, self.test_batch_ua, audio_for_tensorboard_test = [], [], [], [], None
        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(x, y)):
                label = tensor(label).float()
                audio_data = tensor(audio_data).float()
                test_predictions = self.network(audio_data).squeeze(1)
                test_loss = self.loss_function(test_predictions, label)
                test_predictions = nn.Sigmoid()(test_predictions)
                predictions.append(test_predictions.numpy())
                test_accuracy, test_uar = accuracy_fn(test_predictions, label,
                                                      self.threshold)
                self.test_batch_loss.append(test_loss.numpy())
                self.test_batch_accuracy.append(test_accuracy.numpy())
                self.test_batch_uar.append(test_uar)
                tp, fp, tn, fn = custom_confusion_matrix(
                    test_predictions, label, threshold=self.threshold)
                predictions_dict['tp'].extend(tp)
                predictions_dict['fp'].extend(fp)
                predictions_dict['tn'].extend(tn)
                predictions_dict['fn'].extend(fn)

        print(f'***** {type} Metrics ***** ')
        print(f'***** {type} Metrics ***** ', file=self.log_file)
        print(
            f"Loss: {np.mean(self.test_batch_loss)} | Accuracy: {np.mean(self.test_batch_accuracy)} | UAR: {np.mean(self.test_batch_uar)}"
        )
        print(
            f"Loss: {np.mean(self.test_batch_loss)} | Accuracy: {np.mean(self.test_batch_accuracy)} | UAR: {np.mean(self.test_batch_uar)}",
            file=self.log_file)

        log_summary(self.writer,
                    epoch,
                    accuracy=np.mean(self.test_batch_accuracy),
                    loss=np.mean(self.test_batch_loss),
                    uar=np.mean(self.test_batch_uar),
                    lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                    type=type)
        log_conf_matrix(self.writer,
                        epoch,
                        predictions_dict=predictions_dict,
                        type=type)

        y = [element for sublist in y for element in sublist]
        predictions = [
            element for sublist in predictions for element in sublist
        ]
        write_to_npy(filename=self.debug_filename,
                     predictions=predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=np.mean(self.test_batch_accuracy),
                     loss=np.mean(self.test_batch_loss),
                     uar=np.mean(self.test_batch_uar),
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     predictions_dict=predictions_dict,
                     type=type)
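
The accuracy_fn helper these snippets lean on is not shown anywhere on this page. Below is a minimal sketch, assuming sigmoid-activated predictions and binary labels; it returns the five metrics unpacked in Examples #2 to #5, while Examples #1 and #6 call an earlier variant that returns only the first two values.

    import torch

    # Hypothetical reconstruction of accuracy_fn; the eps guard against
    # empty classes is an assumption, not part of the original code.
    def accuracy_fn(predictions, labels, threshold, eps=1e-9):
        # Binarise the sigmoid outputs at the decision threshold.
        predicted = (predictions > threshold).float()
        tp = ((predicted == 1) & (labels == 1)).sum().float()
        tn = ((predicted == 0) & (labels == 0)).sum().float()
        fp = ((predicted == 1) & (labels == 0)).sum().float()
        fn = ((predicted == 0) & (labels == 1)).sum().float()
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        recall = tp / (tp + fn + eps)          # sensitivity
        specificity = tn / (tn + fp + eps)
        uar = (recall + specificity) / 2       # unweighted average recall
        precision = tp / (tp + fp + eps)
        f1 = 2 * precision * recall / (precision + recall + eps)
        return accuracy, uar, precision, recall, f1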
Example #2
    def train(self):

        # For purposes of calculating normalized values, call this method with train data followed by test
        train_data, train_labels = self.data_reader(
            self.data_read_path +
            'train_challenge_with_d1_raw_16k_data_4sec.npy',
            self.data_read_path +
            'train_challenge_with_d1_raw_16k_labels_4sec.npy',
            shuffle=True,
            train=True)
        dev_data, dev_labels = self.data_reader(
            self.data_read_path +
            'dev_challenge_with_d1_raw_16k_data_4sec.npy',
            self.data_read_path +
            'dev_challenge_with_d1_raw_16k_labels_4sec.npy',
            shuffle=False,
            train=False)
        test_data, test_labels = self.data_reader(
            self.data_read_path +
            'test_challenge_with_d1_raw_16k_data_4sec.npy',
            self.data_read_path +
            'test_challenge_with_d1_raw_16k_labels_4sec.npy',
            shuffle=False,
            train=False)

        # For the purposes of assigning pos weight on the fly we are initializing the cost function here
        self.loss_function = nn.BCEWithLogitsLoss(
            pos_weight=to_tensor(self.pos_weight, device=self.device))
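        # Illustration (an assumption, not in the original): "on the fly"
        # here means self.pos_weight can be derived from the freshly loaded
        # training labels, typically as the negative-to-positive ratio, e.g.
        #   flat_labels = np.concatenate(train_labels)
        #   self.pos_weight = (len(flat_labels) - flat_labels.sum()) / flat_labels.sum()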

        total_step = len(train_data)
        for epoch in range(1, self.epochs):
            self.network.train()
            self.batch_loss, self.batch_accuracy, self.batch_uar, self.batch_f1, self.batch_precision, self.batch_recall, audio_for_tensorboard_train = [], [], [], [], [], [], None
            for i, (audio_data,
                    label) in enumerate(zip(train_data, train_labels)):
                self.optimiser.zero_grad()
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                # if i == 0:
                #     self.writer.add_graph(self.network, audio_data)
                predictions = self.network(audio_data).squeeze(1)
                loss = self.loss_function(predictions, label)
                predictions = nn.Sigmoid()(predictions)
                loss.backward()
                self.optimiser.step()
                accuracy, uar, precision, recall, f1 = accuracy_fn(
                    predictions, label, self.threshold)
                self.batch_loss.append(to_numpy(loss))
                self.batch_accuracy.append(to_numpy(accuracy))
                self.batch_uar.append(uar)
                self.batch_f1.append(f1)
                self.batch_precision.append(precision)
                self.batch_recall.append(recall)

                if i % self.display_interval == 0:
                    print(
                        "****************************************************************"
                    )
                    print("predictions mean",
                          torch.mean(predictions).detach().cpu().numpy())
                    print("predictions mean",
                          torch.mean(predictions).detach().cpu().numpy(),
                          file=self.log_file)
                    print("predictions sum",
                          torch.sum(predictions).detach().cpu().numpy())
                    print("predictions sum",
                          torch.sum(predictions).detach().cpu().numpy(),
                          file=self.log_file)
                    print("predictions range",
                          torch.min(predictions).detach().cpu().numpy(),
                          torch.max(predictions).detach().cpu().numpy())
                    print("predictions range",
                          torch.min(predictions).detach().cpu().numpy(),
                          torch.max(predictions).detach().cpu().numpy(),
                          file=self.log_file)
                    print("predictions hist",
                          np.histogram(predictions.detach().cpu().numpy()))
                    print("predictions hist",
                          np.histogram(predictions.detach().cpu().numpy()),
                          file=self.log_file)
                    print("predictions variance",
                          torch.var(predictions).detach().cpu().numpy())
                    print("predictions variance",
                          torch.var(predictions).detach().cpu().numpy(),
                          file=self.log_file)
                    print(
                        "****************************************************************"
                    )
                    print(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.3f' % loss} | Accuracy: {'%.3f' % accuracy} | UAR: {'%.3f' % uar}| F1:{'%.3f' % f1} | Precision: {'%.3f' % precision} | Recall: {'%.3f' % recall}"
                    )
                    print(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.3f' % loss} | Accuracy: {'%.3f' % accuracy} | UAR: {'%.3f' % uar}| F1:{'%.3f' % f1} | Precision: {'%.3f' % precision} | Recall: {'%.3f' % recall}",
                        file=self.log_file)

            # Decay learning rate
            # self.scheduler.step(epoch=epoch)
            log_summary(
                self.writer,
                epoch,
                accuracy=np.mean(self.batch_accuracy),
                loss=np.mean(self.batch_loss),
                uar=np.mean(self.batch_uar),
                lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                type='Train')
            print('***** Overall Train Metrics ***** ')
            print('***** Overall Train Metrics ***** ', file=self.log_file)
            print(
                f"Loss: {'%.3f' % np.mean(self.batch_loss)} | Accuracy: {'%.3f' % np.mean(self.batch_accuracy)} | UAR: {'%.3f' % np.mean(self.batch_uar)} | F1:{'%.3f' % np.mean(self.batch_f1)} | Precision:{'%.3f' % np.mean(self.batch_precision)} | Recall:{'%.3f' % np.mean(self.batch_recall)}"
            )
            print(
                f"Loss: {'%.3f' % np.mean(self.batch_loss)} | Accuracy: {'%.3f' % np.mean(self.batch_accuracy)} | UAR: {'%.3f' % np.mean(self.batch_uar)} | F1:{'%.3f' % np.mean(self.batch_f1)} | Precision:{'%.3f' % np.mean(self.batch_precision)} | Recall:{'%.3f' % np.mean(self.batch_recall)}",
                file=self.log_file)
            print('Learning rate ',
                  self.optimiser.state_dict()['param_groups'][0]['lr'])
            print('Learning rate ',
                  self.optimiser.state_dict()['param_groups'][0]['lr'],
                  file=self.log_file)
            # dev data
            self.run_for_epoch(epoch, dev_data, dev_labels, type='Dev')

            # test data
            self.run_for_epoch(epoch, test_data, test_labels, type='Test')

            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                torch.save(self.network.state_dict(), save_path)
                print('Network successfully saved: ' + save_path)
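
Examples #2 through #5 move data between numpy and the GPU through to_tensor and to_numpy, which are imported from the author's utilities and not shown. A minimal sketch of what they plausibly do, assuming numpy or list inputs and tensors that may still carry gradients:

    import torch

    def to_tensor(data, device='cpu'):
        # Accepts numpy arrays, lists, or existing tensors.
        return torch.as_tensor(data, dtype=torch.float32, device=device)

    def to_numpy(t):
        # Detach first so tensors that require grad can be converted.
        return t.detach().cpu().numpy()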
Example #3
    def run_for_epoch(self, epoch, x, y, type):
        self.network.eval()
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.track_running_stats = False
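        # NOTE: flipping track_running_stats after training is a workaround
        # that, in older PyTorch releases, makes BatchNorm normalise eval
        # batches with batch statistics instead of the stored running
        # averages.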
        predictions_dict = {"tp": [], "fp": [], "tn": [], "fn": []}
        overall_predictions = []

        self.test_batch_loss, self.test_batch_accuracy, self.test_batch_uar, self.test_batch_ua, self.test_batch_f1, self.test_batch_precision, self.test_batch_recall, audio_for_tensorboard_test = [], [], [], [], [], [], [], None
        with torch.no_grad():
            for i, (audio_data, label) in enumerate(zip(x, y)):
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                test_predictions = self.network(audio_data).squeeze(1)
                test_loss = self.loss_function(test_predictions, label)
                test_predictions = nn.Sigmoid()(test_predictions)
                overall_predictions.extend(to_numpy(test_predictions))
                test_accuracy, test_uar, test_precision, test_recall, test_f1 = accuracy_fn(
                    test_predictions, label, self.threshold)
                self.test_batch_loss.append(to_numpy(test_loss))
                self.test_batch_accuracy.append(to_numpy(test_accuracy))
                self.test_batch_uar.append(test_uar)
                self.test_batch_f1.append(test_f1)
                self.test_batch_precision.append(test_precision)
                self.test_batch_recall.append(test_recall)

                tp, fp, tn, fn = custom_confusion_matrix(
                    test_predictions, label, threshold=self.threshold)
                predictions_dict['tp'].extend(tp)
                predictions_dict['fp'].extend(fp)
                predictions_dict['tn'].extend(tn)
                predictions_dict['fn'].extend(fn)

        print(f'***** {type} Metrics ***** ')
        print(f'***** {type} Metrics ***** ', file=self.log_file)
        print(
            f"Loss: {'%.3f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.3f' % np.mean(self.test_batch_accuracy)} | UAR: {'%.3f' % np.mean(self.test_batch_uar)}| F1:{'%.3f' % np.mean(self.test_batch_f1)} | Precision:{'%.3f' % np.mean(self.test_batch_precision)} | Recall:{'%.3f' % np.mean(self.test_batch_recall)}"
        )
        print(
            f"Loss: {'%.3f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.3f' % np.mean(self.test_batch_accuracy)} | UAR: {'%.3f' % np.mean(self.test_batch_uar)}| F1:{'%.3f' % np.mean(self.test_batch_f1)} | Precision:{'%.3f' % np.mean(self.test_batch_precision)} | Recall:{'%.3f' % np.mean(self.test_batch_recall)}",
            file=self.log_file)

        print(
            "****************************************************************")
        print(np.array(overall_predictions).shape)
        print(f"{type} predictions mean", np.mean(overall_predictions))
        print(f"{type} predictions mean",
              np.mean(overall_predictions),
              file=self.log_file)
        print(f"{type} predictions sum", np.sum(overall_predictions))
        print(f"{type} predictions sum",
              np.sum(overall_predictions),
              file=self.log_file)
        print(f"{type} predictions range", np.min(overall_predictions),
              np.max(overall_predictions))
        print(f"{type} predictions range",
              np.min(overall_predictions),
              np.max(overall_predictions),
              file=self.log_file)
        print(f"{type} predictions hist", np.histogram(overall_predictions))
        print(f"{type} predictions hist",
              np.histogram(overall_predictions),
              file=self.log_file)
        print(f"{type} predictions variance", np.var(overall_predictions))
        print(f"{type} predictions variance",
              np.var(overall_predictions),
              file=self.log_file)
        print(
            "****************************************************************")

        log_summary(self.writer,
                    epoch,
                    accuracy=np.mean(self.test_batch_accuracy),
                    loss=np.mean(self.test_batch_loss),
                    uar=np.mean(self.test_batch_uar),
                    lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                    type=type)
        log_conf_matrix(self.writer,
                        epoch,
                        predictions_dict=predictions_dict,
                        type=type)

        y = [element for sublist in y for element in sublist]
        write_to_npy(filename=self.debug_filename,
                     predictions=overall_predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=np.mean(self.test_batch_accuracy),
                     loss=np.mean(self.test_batch_loss),
                     uar=np.mean(self.test_batch_uar),
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     predictions_dict=predictions_dict,
                     type=type)
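
log_summary and log_conf_matrix come from the author's TensorBoard utilities and are not reproduced on this page. A plausible minimal log_summary over torch.utils.tensorboard's SummaryWriter, under the assumption that it simply emits one scalar per metric:

    # Hypothetical sketch; the tag names are assumptions.
    def log_summary(writer, epoch, accuracy, loss, uar, lr, type):
        writer.add_scalar(f'{type}/Loss', loss, epoch)
        writer.add_scalar(f'{type}/Accuracy', accuracy, epoch)
        writer.add_scalar(f'{type}/UAR', uar, epoch)
        writer.add_scalar(f'{type}/LearningRate', lr, epoch)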
Example #4
    def train(self):

        # For purposes of calculating normalized values, call this method with train data followed by test
        train_inp_file = 'train_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy'
        train_out_file = 'train_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy'
        train_jitter_file = 'train_challenge_with_shimmer_jitter.npy'
        dev_inp_file = 'dev_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy'
        dev_out_file = 'dev_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy'
        dev_jitter_file = 'dev_challenge_with_shimmer_jitter.npy'
        test_inp_file = 'test_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_data.npy'
        test_out_file = 'test_challenge_with_d1_mel_power_to_db_fnot_zr_crossing_labels.npy'
        test_jitter_file = 'test_challenge_with_shimmer_jitter.npy'

        self.logger.info(
            f'Reading train files {train_inp_file}, {train_out_file}')
        train_data, train_labels, train_jitter_data = self.data_reader(
            self.data_read_path + train_inp_file,
            self.data_read_path + train_out_file,
            self.data_read_path + train_jitter_file,
            shuffle=True,
            train=True)
        self.logger.info(f'Reading dev files {dev_inp_file}, {dev_out_file}')
        dev_data, dev_labels, dev_jitter_data = self.data_reader(
            self.data_read_path + dev_inp_file,
            self.data_read_path + dev_out_file,
            self.data_read_path + dev_jitter_file,
            shuffle=False,
            train=False)
        self.logger.info(
            f'Reading test files {test_inp_file}, {test_out_file}')
        test_data, test_labels, test_jitter_data = self.data_reader(
            self.data_read_path + test_inp_file,
            self.data_read_path + test_out_file,
            self.data_read_path + test_jitter_file,
            shuffle=False,
            train=False)

        # For the purposes of assigning pos weight on the fly we are initializing the cost function here
        self.loss_function = nn.BCEWithLogitsLoss(
            pos_weight=to_tensor(self.pos_weight, device=self.device))

        total_step = len(train_data)
        for epoch in range(1, self.epochs):
            log_learnable_parameter(
                self.writer,
                epoch - 1,
                network_params=self.network.named_parameters())
            self.network.train()
            self.batch_loss, self.batch_accuracy, self.batch_uar, self.batch_f1, self.batch_precision, \
            self.batch_recall, train_predictions, train_logits, audio_for_tensorboard_train = [], [], [], [], [], [], [], [], None
            for i, (audio_data, label, jitter_shimmer_data) in enumerate(
                    zip(train_data, train_labels, train_jitter_data)):
                self.optimiser.zero_grad()
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                jitter_shimmer_data = to_tensor(jitter_shimmer_data,
                                                device=self.device)
                if i == 0:
                    self.writer.add_graph(self.network,
                                          (audio_data, jitter_shimmer_data))
                predictions = self.network(audio_data,
                                           jitter_shimmer_data).squeeze(1)
                # Detach before storing so the autograd graph is not
                # retained across the whole epoch.
                train_logits.extend(to_numpy(predictions))
                loss = self.loss_function(predictions, label)
                predictions = nn.Sigmoid()(predictions)
                train_predictions.extend(to_numpy(predictions))
                loss.backward()
                self.optimiser.step()
                accuracy, uar, precision, recall, f1 = accuracy_fn(
                    predictions, label, self.threshold)
                self.batch_loss.append(to_numpy(loss))
                self.batch_accuracy.append(to_numpy(accuracy))
                self.batch_uar.append(uar)
                self.batch_f1.append(f1)
                self.batch_precision.append(precision)
                self.batch_recall.append(recall)

                if i % self.display_interval == 0:
                    self.logger.info(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {'%.3f' % loss} | Accuracy: {'%.3f' % accuracy} | UAR: {'%.3f' % uar}| F1:{'%.3f' % f1} | Precision: {'%.3f' % precision} | Recall: {'%.3f' % recall}"
                    )

            log_learnable_parameter(self.writer,
                                    epoch,
                                    to_tensor(train_logits,
                                              device=self.device),
                                    name='train_logits')
            log_learnable_parameter(self.writer,
                                    epoch,
                                    to_tensor(train_predictions,
                                              device=self.device),
                                    name='train_activated')

            # Decay learning rate
            self.scheduler.step(epoch=epoch)
            log_summary(
                self.writer,
                epoch,
                accuracy=np.mean(self.batch_accuracy),
                loss=np.mean(self.batch_loss),
                uar=np.mean(self.batch_uar),
                lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                type='Train')
            self.logger.info('***** Overall Train Metrics ***** ')
            self.logger.info(
                f"Loss: {'%.3f' % np.mean(self.batch_loss)} | Accuracy: {'%.3f' % np.mean(self.batch_accuracy)} | UAR: {'%.3f' % np.mean(self.batch_uar)} | F1:{'%.3f' % np.mean(self.batch_f1)} | Precision:{'%.3f' % np.mean(self.batch_precision)} | Recall:{'%.3f' % np.mean(self.batch_recall)}"
            )
            self.logger.info(
                f"Learning rate {self.optimiser.state_dict()['param_groups'][0]['lr']}"
            )

            # dev data
            self.run_for_epoch(epoch,
                               dev_data,
                               dev_labels,
                               dev_jitter_data,
                               type='Dev')

            # test data
            self.run_for_epoch(epoch,
                               test_data,
                               test_labels,
                               test_jitter_data,
                               type='Test')

            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                torch.save(self.network.state_dict(), save_path)
                self.logger.info(f'Network successfully saved: {save_path}')
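
write_to_npy, used for the per-epoch debug dumps, is another helper that is not shown. One way it could work, assuming it appends one record per call to a pickled .npy file:

    import numpy as np

    # Hypothetical sketch of write_to_npy; the append-to-file behaviour is
    # an assumption about the original.
    def write_to_npy(filename, **kwargs):
        try:
            records = list(np.load(filename, allow_pickle=True))
        except FileNotFoundError:
            records = []
        records.append(kwargs)
        np.save(filename, np.array(records, dtype=object))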
Example #5
    def run_for_epoch(self, epoch, x, y, jitterx, type):
        # self.network.eval()
        # for m in self.network.modules():
        #     if isinstance(m, nn.BatchNorm2d):
        #         m.track_running_stats = False
        predictions_dict = {"tp": [], "fp": [], "tn": [], "fn": []}
        logits, predictions = [], []
        self.test_batch_loss, self.test_batch_accuracy, self.test_batch_uar, self.test_batch_ua, self.test_batch_f1, self.test_batch_precision, self.test_batch_recall, audio_for_tensorboard_test = [], [], [], [], [], [], [], None
        with torch.no_grad():
            for i, (audio_data, label,
                    jitter_shimmer_data) in enumerate(zip(x, y, jitterx)):
                label = to_tensor(label, device=self.device).float()
                audio_data = to_tensor(audio_data, device=self.device)
                jitter_shimmer_data = to_tensor(jitter_shimmer_data,
                                                device=self.device)
                test_predictions = self.network(audio_data,
                                                jitter_shimmer_data).squeeze(1)
                logits.extend(to_numpy(test_predictions))
                test_loss = self.loss_function(test_predictions, label)
                test_predictions = nn.Sigmoid()(test_predictions)
                predictions.append(to_numpy(test_predictions))
                test_accuracy, test_uar, test_precision, test_recall, test_f1 = accuracy_fn(
                    test_predictions, label, self.threshold)
                self.test_batch_loss.append(to_numpy(test_loss))
                self.test_batch_accuracy.append(to_numpy(test_accuracy))
                self.test_batch_uar.append(test_uar)
                self.test_batch_f1.append(test_f1)
                self.test_batch_precision.append(test_precision)
                self.test_batch_recall.append(test_recall)

                tp, fp, tn, fn = custom_confusion_matrix(
                    test_predictions, label, threshold=self.threshold)
                predictions_dict['tp'].extend(tp)
                predictions_dict['fp'].extend(fp)
                predictions_dict['tn'].extend(tn)
                predictions_dict['fn'].extend(fn)

        predictions = [
            element for sublist in predictions for element in sublist
        ]
        self.logger.info(f'***** {type} Metrics ***** ')
        self.logger.info(
            f"Loss: {'%.3f' % np.mean(self.test_batch_loss)} | Accuracy: {'%.3f' % np.mean(self.test_batch_accuracy)} | UAR: {'%.3f' % np.mean(self.test_batch_uar)}| F1:{'%.3f' % np.mean(self.test_batch_f1)} | Precision:{'%.3f' % np.mean(self.test_batch_precision)} | Recall:{'%.3f' % np.mean(self.test_batch_recall)}"
        )

        log_summary(self.writer,
                    epoch,
                    accuracy=np.mean(self.test_batch_accuracy),
                    loss=np.mean(self.test_batch_loss),
                    uar=np.mean(self.test_batch_uar),
                    lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                    type=type)
        log_conf_matrix(self.writer,
                        epoch,
                        predictions_dict=predictions_dict,
                        type=type)

        log_learnable_parameter(self.writer,
                                epoch,
                                to_tensor(logits, device=self.device),
                                name=f'{type}_logits')
        log_learnable_parameter(self.writer,
                                epoch,
                                to_tensor(predictions, device=self.device),
                                name=f'{type}_predictions')

        write_to_npy(filename=self.debug_filename,
                     predictions=predictions,
                     labels=y,
                     epoch=epoch,
                     accuracy=np.mean(self.test_batch_accuracy),
                     loss=np.mean(self.test_batch_loss),
                     uar=np.mean(self.test_batch_uar),
                     lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                     predictions_dict=predictions_dict,
                     type=type)
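
custom_confusion_matrix is also undefined on this page. Judging by how its four outputs are extend()-ed into lists, it appears to return the predicted probabilities bucketed by outcome rather than scalar counts; a sketch under that assumption:

    # Hypothetical reconstruction of custom_confusion_matrix.
    def custom_confusion_matrix(predictions, labels, threshold):
        tp, fp, tn, fn = [], [], [], []
        for p, l in zip(predictions.tolist(), labels.tolist()):
            predicted = 1 if p > threshold else 0
            if predicted == 1 and l == 1:
                tp.append(p)
            elif predicted == 1 and l == 0:
                fp.append(p)
            elif predicted == 0 and l == 0:
                tn.append(p)
            else:
                fn.append(p)
        return tp, fp, tn, fn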
Example #6
    def train(self):

        # For purposes of calculating normalized values, call this method with train data followed by test
        train_data, train_labels = self.data_reader(
            self.data_read_path + 'train_challenge_with_d1_data.npy',
            self.data_read_path + 'train_challenge_with_d1_labels.npy',
            shuffle=True,
            train=True)
        dev_data, dev_labels = self.data_reader(
            self.data_read_path + 'dev_challenge_with_d1_data.npy',
            self.data_read_path + 'dev_challenge_with_d1_labels.npy',
            shuffle=False,
            train=False)
        test_data, test_labels = self.data_reader(
            self.data_read_path + 'test_challenge_data.npy',
            self.data_read_path + 'test_challenge_labels.npy',
            shuffle=False,
            train=False)

        # For the purposes of assigning pos weight on the fly we are initializing the cost function here
        self.loss_function = nn.BCEWithLogitsLoss(
            pos_weight=tensor(self.pos_weight))

        total_step = len(train_data)
        for epoch in range(1, self.epochs):
            self.network.train()
            self.batch_loss, self.batch_accuracy, self.batch_uar, audio_for_tensorboard_train = [], [], [], None
            for i, (audio_data,
                    label) in enumerate(zip(train_data, train_labels)):
                self.optimiser.zero_grad()
                label = tensor(label).float()
                # Convert the batch once so both add_graph and the forward
                # pass receive a tensor.
                audio_data = tensor(audio_data).float()
                if i == 0:
                    self.writer.add_graph(self.network, audio_data)
                predictions = self.network(audio_data).squeeze(1)
                loss = self.loss_function(predictions, label)
                loss.backward()
                self.optimiser.step()
                predictions = nn.Sigmoid()(predictions)
                accuracy, uar = accuracy_fn(predictions, label, self.threshold)
                self.batch_loss.append(loss.detach().numpy())
                self.batch_accuracy.append(accuracy)
                self.batch_uar.append(uar)

                if i % self.display_interval == 0:
                    print(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {loss} | Accuracy: {accuracy} | UAR: {uar}"
                    )
                    print(
                        f"Epoch: {epoch}/{self.epochs} | Step: {i}/{total_step} | Loss: {loss} | Accuracy: {accuracy} | UAR: {uar}",
                        file=self.log_file)
            # Decay learning rate
            self.scheduler.step(epoch=epoch)
            log_summary(
                self.writer,
                epoch,
                accuracy=np.mean(self.batch_accuracy),
                loss=np.mean(self.batch_loss),
                uar=np.mean(self.batch_uar),
                lr=self.optimiser.state_dict()['param_groups'][0]['lr'],
                type='Train')
            print('***** Overall Train Metrics ***** ')
            print('***** Overall Train Metrics ***** ', file=self.log_file)
            print(
                f"Loss: {np.mean(self.batch_loss)} | Accuracy: {np.mean(self.batch_accuracy)} | UAR: {np.mean(self.batch_uar)} "
            )
            print(
                f"Loss: {np.mean(self.batch_loss)} | Accuracy: {np.mean(self.batch_accuracy)} | UAR: {np.mean(self.batch_uar)} ",
                file=self.log_file)
            print('Learning rate ',
                  self.optimiser.state_dict()['param_groups'][0]['lr'],
                  file=self.log_file)

            # dev data
            self.run_for_epoch(epoch, dev_data, dev_labels, type='Dev')

            # test data
            self.run_for_epoch(epoch, test_data, test_labels, type='Test')

            if epoch % self.network_save_interval == 0:
                save_path = self.network_save_path + '/' + self.run_name + '_' + str(
                    epoch) + '.pt'
                torch.save(self.network.state_dict(), save_path)
                print('Network successfully saved: ' + save_path)
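
Every example above funnels its inputs through self.data_reader, which is likewise not shown. A minimal sketch of the two-file variant, assuming pre-batched object arrays saved with numpy (Example #4's version takes a third jitter/shimmer path and returns a third array):

    import numpy as np

    # Hypothetical sketch; it shuffles batch order only, since the batches
    # are assumed to be pre-built.
    def data_reader(data_path, label_path, shuffle=True, train=True):
        data = np.load(data_path, allow_pickle=True)
        labels = np.load(label_path, allow_pickle=True)
        if shuffle:
            idx = np.random.permutation(len(data))
            data, labels = data[idx], labels[idx]
        return data, labels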