def train(self):
        """Run the training loop: one training pass and one validation pass per epoch.

        Steps, in order:
          * print the run arguments via ``self.print_args()``;
          * create ``<audio_dumps_path>/<workname>/train`` and
            ``<visual_dumps_path>/<workname>/train`` and store both paths
            on ``self``;
          * build shuffled data loaders for the ``'train'`` and ``'val'``
            splits of ``UnetInput``;
          * for every epoch from ``self.start_epoch`` up to ``self.EPOCHS``
            run a training pass, step the scheduler with ``self.loss``,
            run a validation pass and ask ``self.EarlyStopChecker``
            whether to stop.

        The loop variable is stored directly in ``self.epoch``.
        """
        self.print_args()

        # Output directories for this run's training dumps.
        audio_dir = os.path.join(self.audio_dumps_path, self.workname, 'train')
        self.audio_dumps_folder = audio_dir
        create_folder(audio_dir)

        visual_dir = os.path.join(self.visual_dumps_path, self.workname, 'train')
        self.visual_dumps_folder = visual_dir
        create_folder(visual_dir)

        # Both loaders share the same configuration.
        loader_options = {
            'batch_size': BATCH_SIZE,
            'shuffle': True,
            'num_workers': 10,
        }
        self.train_loader = torch.utils.data.DataLoader(UnetInput('train'),
                                                        **loader_options)
        self.val_loader = torch.utils.data.DataLoader(UnetInput('val'),
                                                      **loader_options)

        for self.epoch in range(self.start_epoch, self.EPOCHS):
            # NOTE(review): `train(self)` / `val(self)` are context managers
            # defined outside this block; presumably they switch the run
            # mode before `run_epoch` - confirm against their definition.
            with train(self):
                self.run_epoch(self.train_iter_logger)
            self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()

            val_history = self.loss_.data.tuple['val'].epoch_array.val
            should_stop = self.EarlyStopChecker.check_improvement(
                val_history, self.epoch)
            if not should_stop:
                continue

            message = ('Early Stopping Epoch : [{0}], '
                       'Best Checkpoint Epoch : [{1}]')
            print(message.format(self.epoch, self.EarlyStopChecker.best_epoch))
            break
Exemplo n.º 2
0
    def train(self):
        """Run the training loop with Dynamic Weight Average (DWA) loss weights.

        Steps, in order:
          * print the run arguments via ``self.print_args()``;
          * create ``<audio_dumps_path>/<workname>/train`` and
            ``<visual_dumps_path>/<workname>/train`` and store both paths
            on ``self``;
          * build shuffled data loaders for the ``'train'`` and ``'val'``
            splits of ``UnetInput`` and record their batch counts;
          * allocate ``self.avg_cost`` (``EPOCHS x K`` zeros) and
            ``self.lambda_weight`` (``len(SOURCES_SUBSET) x EPOCHS`` ones);
          * for every epoch from ``self.start_epoch`` up to ``self.EPOCHS``
            compute the DWA weights, run a training pass, step the
            scheduler with ``self.loss``, run a validation pass and ask
            ``self.EarlyStopChecker`` whether to stop.

        DWA weights for epoch ``t`` are
        ``K * exp(w_k / T) / sum_i exp(w_i / T)`` with
        ``w_k = avg_cost[t - 1, k] / avg_cost[t - 2, k]`` and
        ``T = self.DWA_T``.  They need two completed epochs of cost
        history, so the first two epochs *of this run* use a weight of 1.0.

        The loop variable is stored directly in ``self.epoch``.
        """
        self.print_args()
        self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
        create_folder(self.audio_dumps_folder)
        self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
        create_folder(self.visual_dumps_folder)

        training_data = UnetInput('train')
        self.train_loader = torch.utils.data.DataLoader(training_data,
                                                        batch_size=BATCH_SIZE,
                                                        shuffle=True,
                                                        num_workers=10)
        self.train_batches = len(self.train_loader)
        validation_data = UnetInput('val')
        self.val_loader = torch.utils.data.DataLoader(validation_data,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=True,
                                                      num_workers=10)
        self.val_batches = len(self.val_loader)

        self.avg_cost = np.zeros([self.EPOCHS, self.K], dtype=np.float32)
        self.lambda_weight = np.ones([len(SOURCES_SUBSET), self.EPOCHS])
        for self.epoch in range(self.start_epoch, self.EPOCHS):
            self.cost = list(torch.zeros(self.K))

            # Apply Dynamic Weight Average.
            # `avg_cost` was re-allocated as zeros just above, so the rows
            # before `start_epoch` hold no history.  The warm-up is therefore
            # counted from `start_epoch`, not from absolute epoch 0; otherwise
            # a run starting at a later epoch would divide by zero.
            if self.epoch - self.start_epoch < 2:
                self.lambda_weight[:, self.epoch] = 1.0
            else:
                # Relative descent rate of every task loss over the last two
                # epochs, computed for all K tasks at once.
                ratios = (self.avg_cost[self.epoch - 1, :]
                          / self.avg_cost[self.epoch - 2, :])
                # Keep exposing the per-task ratios as self.w_1 ... self.w_K
                # in case other code reads them.
                for task_number, ratio in enumerate(ratios, start=1):
                    setattr(self, 'w_{}'.format(task_number), ratio)
                # Softmax with temperature DWA_T, scaled so weights sum to K.
                exp_terms = np.exp(ratios.astype(np.float64) / self.DWA_T)
                self.lambda_weight[:self.K, self.epoch] = (
                    self.K * exp_terms / exp_terms.sum())

            # NOTE(review): `train(self)` / `val(self)` are context managers
            # defined outside this block; presumably they switch the run
            # mode before `run_epoch` - confirm against their definition.
            with train(self):
                self.run_epoch(self.train_iter_logger)
            self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()

            stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
                                                           self.epoch)
            if stop:
                print('Early Stopping Epoch : [{0}]'.format(self.epoch))
                break
            # Only the first two task costs are reported.
            print(
                'Epoch: {:04d} | TRAIN: {:.4f} {:.4f}'.format(self.epoch,
                                                              self.avg_cost[self.epoch, 0],
                                                              self.avg_cost[self.epoch, 1]))