def train(self):
    """Run the training loop with per-epoch validation and early stopping.

    Prints the run arguments, creates the audio and visual dump folders for
    the 'train' phase, builds shuffled data loaders for the 'train' and 'val'
    splits, and then iterates epochs from ``self.start_epoch`` up to (but not
    including) ``self.EPOCHS``. Each epoch runs one training pass, steps the
    scheduler with ``self.loss``, runs one validation pass, and asks
    ``self.EarlyStopChecker`` whether training should stop.
    """
    self.print_args()

    # Dump folders live under <dumps_path>/<workname>/train.
    audio_dir = os.path.join(self.audio_dumps_path, self.workname, 'train')
    self.audio_dumps_folder = audio_dir
    create_folder(audio_dir)

    visual_dir = os.path.join(self.visual_dumps_path, self.workname, 'train')
    self.visual_dumps_folder = visual_dir
    create_folder(visual_dir)

    def make_loader(split):
        # Both splits share the same loader configuration.
        dataset = UnetInput(split)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True,
                                           num_workers=10)

    self.train_loader = make_loader('train')
    self.val_loader = make_loader('val')

    # NOTE(review): the original indentation was lost; the scheduler step and
    # the early-stop check are assumed to run after their `with` blocks exit
    # -- confirm against the upstream file.
    for epoch in range(self.start_epoch, self.EPOCHS):
        self.epoch = epoch

        with train(self):
            self.run_epoch(self.train_iter_logger)
        self.scheduler.step(self.loss)

        with val(self):
            self.run_epoch()

        val_history = self.loss_.data.tuple['val'].epoch_array.val
        if self.EarlyStopChecker.check_improvement(val_history, self.epoch):
            message = ('Early Stopping Epoch : [{0}], '
                       'Best Checkpoint Epoch : [{1}]')
            print(message.format(self.epoch,
                                 self.EarlyStopChecker.best_epoch))
            break
def train(self):
    """Run the training loop with Dynamic Weight Average task weighting.

    Prints the run arguments, creates the audio and visual dump folders for
    the 'train' phase, builds shuffled data loaders for the 'train' and 'val'
    splits (recording their batch counts), and allocates the per-epoch cost
    history ``self.avg_cost`` (EPOCHS x K) and task-weight table
    ``self.lambda_weight`` (len(SOURCES_SUBSET) x EPOCHS).

    For every epoch from ``self.start_epoch`` up to (but not including)
    ``self.EPOCHS`` it resets ``self.cost``, refreshes the task weights for
    the epoch, runs one training pass, steps the scheduler with
    ``self.loss``, runs one validation pass and checks for early stopping.
    A one-line summary of the epoch's average costs is printed unless the
    loop stops early.
    """
    self.print_args()

    # Dump folders live under <dumps_path>/<workname>/train.
    self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
    create_folder(self.audio_dumps_folder)
    self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
    create_folder(self.visual_dumps_folder)

    training_data = UnetInput('train')
    self.train_loader = torch.utils.data.DataLoader(training_data,
                                                    batch_size=BATCH_SIZE,
                                                    shuffle=True,
                                                    num_workers=10)
    self.train_batches = len(self.train_loader)

    validation_data = UnetInput('val')
    self.val_loader = torch.utils.data.DataLoader(validation_data,
                                                  batch_size=BATCH_SIZE,
                                                  shuffle=True,
                                                  num_workers=10)
    self.val_batches = len(self.val_loader)

    # avg_cost is only allocated here; presumably run_epoch fills in the row
    # of the current epoch -- verify against run_epoch.
    self.avg_cost = np.zeros([self.EPOCHS, self.K], dtype=np.float32)
    self.lambda_weight = np.ones([len(SOURCES_SUBSET), self.EPOCHS])

    # NOTE(review): the original indentation was lost; the scheduler step,
    # the early-stop check and the summary print are assumed to sit at loop
    # level, after their `with` blocks exit -- confirm against upstream.
    for self.epoch in range(self.start_epoch, self.EPOCHS):
        self.cost = list(torch.zeros(self.K))

        # apply Dynamic Weight Average
        self._update_dwa_weights()

        with train(self):
            self.run_epoch(self.train_iter_logger)
        self.scheduler.step(self.loss)

        with val(self):
            self.run_epoch()

        stop = self.EarlyStopChecker.check_improvement(
            self.loss_.data.tuple['val'].epoch_array.val, self.epoch)
        if stop:
            print('Early Stopping Epoch : [{0}]'.format(self.epoch))
            break

        # Report every task's cost (identical output to before when K == 2).
        epoch_costs = ' '.join('{:.4f}'.format(cost)
                               for cost in self.avg_cost[self.epoch])
        print('Epoch: {:04d} | TRAIN: {}'.format(self.epoch, epoch_costs))


def _update_dwa_weights(self):
    """Write the Dynamic Weight Average weights for ``self.epoch``.

    Epochs 0 and 1 use a uniform weight of 1.0. Later epochs weight task k
    by ``K * exp(w_k / DWA_T) / sum_i exp(w_i / DWA_T)``, where ``w_k`` is
    the ratio of the task's average cost in the two preceding epochs. This
    works for any ``self.K`` and reproduces the former K == 2 and K == 4
    formulas.

    Falls back to uniform weights when the two preceding cost rows are not
    usable (non-finite, or a zero denominator), which happens for instance
    when ``self.start_epoch >= 2`` and ``self.avg_cost`` has just been
    zero-initialised.
    """
    if self.epoch < 2:
        self.lambda_weight[:, self.epoch] = 1.0
        return

    previous = self.avg_cost[self.epoch - 1].astype(np.float64)
    older = self.avg_cost[self.epoch - 2].astype(np.float64)

    usable = (np.all(np.isfinite(previous))
              and np.all(np.isfinite(older))
              and not np.any(older == 0))
    if not usable:
        # Dividing would give NaN/inf weights; keep the tasks balanced.
        self.lambda_weight[:, self.epoch] = 1.0
        return

    ratios = previous / older

    # The original code exposed the ratios as self.w_1 ... self.w_K; keep
    # them in case other code reads those attributes.
    for index, ratio in enumerate(ratios, start=1):
        setattr(self, 'w_{}'.format(index), ratio)

    # Softmax over ratios / T, shifted by the maximum so np.exp cannot
    # overflow; the shift cancels out in the normalisation.
    scaled = ratios / self.DWA_T
    scaled -= scaled.max()
    exponentials = np.exp(scaled)
    self.lambda_weight[:self.K, self.epoch] = (
        self.K * exponentials / exponentials.sum())