def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
    super(SpecChannelUnetNoMask, self).__init__(model, rootdir, workname, main_device, trackgrad)
    self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
    self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
    self.audio_dumps_folder = None
    self.visual_dumps_folder = None
    self.main_device = main_device
    self.grid_unwarp = torch.from_numpy(
        warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
    self.set_tensor_scalar_item('l1')
    self.set_tensor_scalar_item('l2')
    if K == 4:
        self.set_tensor_scalar_item('l3')
        self.set_tensor_scalar_item('l4')
    self.set_tensor_scalar_item('loss_tracker')
    self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
    self.val_iterations = 0
def train_model(model, batch_size, patience, n_epochs, gpu, plotter_train, plotter_eval):
    train_losses = []       # training loss per iteration
    valid_losses = []       # validation loss per iteration
    avg_train_losses = []   # average training loss per epoch
    avg_valid_losses = []   # average validation loss per epoch

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, n_epochs + 1):
        ###################
        # train the model #
        ###################
        # schedule LR (currently disabled)
        # scheduler.step()
        model.train()  # prep model for training
        t = tqdm(iter(train_loader), desc="[Train on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t):
            # get the inputs
            data = example["image"]
            target = example["fixations"]

            # clear the gradients of all optimized variables
            # optimizer_adam.zero_grad()
            # optimizer_sgd.zero_grad()
            optimizer.zero_grad()

            # push data and targets to the GPU
            if gpu and torch.cuda.is_available():
                data = data.to('cuda')
                target = target.to('cuda')

            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # drop the channel dimension (it is 1) so outputs match the targets,
            # (batch_size, 100, 100); infer the batch dimension since the last
            # batch is generally smaller than e.g. 128
            output = output.view(-1, target.size()[-2], target.size()[-1])

            # calculate the loss
            # loss = myLoss(output, target)
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss w.r.t. model parameters
            loss.backward()

            # An earlier variant updated only the Adam parameters (non center
            # bias) for the first 200 iterations (1st epoch), then alternated
            # between the two optimizers: Adam on even iterations, SGD on uneven ones.
            # if (epoch == 1) & (i < 200):
            #     optimizer_adam.step()
            # elif i % 2 == 0:
            #     optimizer_adam.step()
            # else:
            #     optimizer_sgd.step()
            optimizer.step()

            # record training loss
            train_losses.append(loss.item())

            # for the first epoch, plot the loss per iteration for a quick
            # overview of the early training phase; the plot always appends the
            # newest value, so just pass the last item of the list
            iteration = i + 1
            if epoch == 1:
                plotter_train.plot('loss', 'train', 'Loss per Iteration', iteration,
                                   train_losses[-1], batch_size, lr, 'iteration')

        ######################
        # validate the model #
        ######################
        model.eval()  # prep model for evaluation
        t = tqdm(iter(val_loader), desc="[Valid on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t):
            # get the inputs
            data = example["image"]
            target = example["fixations"]
            # push data and targets to the GPU
            if gpu and torch.cuda.is_available():
                data = data.to('cuda')
                target = target.to('cuda')
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # drop the channel dimension so outputs match the targets
            # (fixed: the original passed target.size()[-2] twice)
            output = output.view(-1, target.size()[-2], target.size()[-1])
            # calculate the loss
            loss = criterion(output, target)
            # record validation loss
            valid_losses.append(loss.item())

        # print training/validation statistics:
        # calculate the average loss over the epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print_msg = ('[{}/{}] '.format(epoch, n_epochs) +
                     'train_loss: {:.5f} '.format(train_loss) +
                     'valid_loss: {:.5f}'.format(valid_loss))
        print(print_msg)

        # plot average losses for this epoch
        plotter_eval.plot('loss', 'train', 'Loss per Epoch', epoch, train_loss, batch_size, lr, 'epoch')
        plotter_eval.plot('loss', 'val', 'Loss per Epoch', epoch, valid_loss, batch_size, lr, 'epoch')

        # clear lists to track the next epoch
        train_losses = []
        valid_losses = []

        # early_stopping needs the validation loss to check whether it has
        # decreased; if it has, it makes a checkpoint of the current model
        early_stopping(valid_loss, model, batch_size, lr)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        # after every epoch, show the center bias parameters
        for name, param in model.named_parameters():
            if param.requires_grad and "center_bias" in name:
                print(name, param, param.grad)

    # load the last checkpoint with the best model
    name = "checkpoint_batch_size_{}_lr_{}.pt".format(batch_size, lr)
    model.load_state_dict(torch.load(name))

    return model, avg_train_losses, avg_valid_losses
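# The EarlyStopping object used above is not defined in this file. Below is a
# minimal sketch of the interface train_model assumes: callable with
# (valid_loss, model, batch_size, lr), exposing an `early_stop` flag, and
# checkpointing a state_dict to "checkpoint_batch_size_{}_lr_{}.pt" whenever the
# validation loss improves. The project's actual implementation may differ.
import numpy as np
import torch

class EarlyStopping:
    """Stops training when the validation loss stops improving (assumed interface)."""

    def __init__(self, patience=7, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, valid_loss, model, batch_size, lr):
        if valid_loss < self.best_loss:
            # improvement: reset the counter and checkpoint the model
            if self.verbose:
                print("Validation loss decreased ({:.6f} -> {:.6f}). Saving model...".format(
                    self.best_loss, valid_loss))
            self.best_loss = valid_loss
            self.counter = 0
            torch.save(model.state_dict(), "checkpoint_batch_size_{}_lr_{}.pt".format(batch_size, lr))
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True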
class DWA(pytorchfw):

    def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
        super(DWA, self).__init__(model, rootdir, workname, main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.audio_dumps_folder = None
        self.visual_dumps_folder = None
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
        self.set_tensor_scalar_item('l1')
        self.set_tensor_scalar_item('l2')
        if K == 4:
            self.set_tensor_scalar_item('l3')
            self.set_tensor_scalar_item('l4')
        self.set_tensor_scalar_item('loss_tracker')
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0

    def print_args(self):
        setup_logger('log_info', self.workdir + '/info_file.txt',
                     FORMAT="[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s]")
        logger = logging.getLogger('log_info')
        self.print_info(logger)
        logger.info(f'\r\t Spectrogram data dir: {ROOT_DIR}\r'
                    'TRAINING PARAMETERS: \r\t'
                    f'Run name: {self.workname}\r\t'
                    f'Batch size {BATCH_SIZE} \r\t'
                    f'Optimizer {OPTIMIZER} \r\t'
                    f'Initializer {INITIALIZER} \r\t'
                    f'Epochs {EPOCHS} \r\t'
                    f'LR General: {LR} \r\t'
                    f'SGD Momentum {MOMENTUM} \r\t'
                    f'Weight Decay {WEIGHT_DECAY} \r\t'
                    f'Pre-trained model: {PRETRAINED} \r'
                    'MODEL PARAMETERS \r\t'
                    f'Nº instruments (K) {K} \r\t'
                    f'U-Net activation: {ACTIVATION} \r\t'
                    f'U-Net Input channels {INPUT_CHANNELS}\r\t'
                    f'U-Net Batch normalization {USE_BN} \r\t')

    def set_optim(self, *args, **kwargs):
        if OPTIMIZER == 'adam':
            kwargs.pop('momentum', None)  # Adam does not accept a momentum kwarg
            return torch.optim.Adam(*args, **kwargs)
        elif OPTIMIZER == 'SGD':
            return torch.optim.SGD(*args, **kwargs)
        else:
            raise ValueError('Unsupported optimizer. Implement it.')

    def hyperparameters(self):
        self.dataparallel = False
        self.initializer = INITIALIZER
        self.EPOCHS = EPOCHS
        self.DWA_T = DWA_TEMP
        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        self.LR = LR
        self.K = K
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=7, threshold=3e-4)

    def set_config(self):
        self.batch_size = BATCH_SIZE
        self.criterion = IndividualLosses(self.main_device)

    @config
    @set_training
    def train(self):
        self.print_args()
        self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
        create_folder(self.audio_dumps_folder)
        self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
        create_folder(self.visual_dumps_folder)

        training_data = UnetInput('train')
        self.train_loader = torch.utils.data.DataLoader(training_data, batch_size=BATCH_SIZE,
                                                        shuffle=True, num_workers=10)
        self.train_batches = len(self.train_loader)
        validation_data = UnetInput('val')
        self.val_loader = torch.utils.data.DataLoader(validation_data, batch_size=BATCH_SIZE,
                                                      shuffle=True, num_workers=10)
        self.val_batches = len(self.val_loader)

        self.avg_cost = np.zeros([self.EPOCHS, self.K], dtype=np.float32)
        self.lambda_weight = np.ones([len(SOURCES_SUBSET), self.EPOCHS])

        for self.epoch in range(self.start_epoch, self.EPOCHS):
            self.cost = list(torch.zeros(self.K))
            # apply Dynamic Weight Average
            if self.epoch == 0 or self.epoch == 1:
                self.lambda_weight[:, self.epoch] = 1.0
            else:
                if K == 2:
                    self.w_1 = self.avg_cost[self.epoch - 1, 0] / self.avg_cost[self.epoch - 2, 0]
                    self.w_2 = self.avg_cost[self.epoch - 1, 1] / self.avg_cost[self.epoch - 2, 1]
                    exp_sum = np.exp(self.w_1 / self.DWA_T) + np.exp(self.w_2 / self.DWA_T)
                    self.lambda_weight[0, self.epoch] = 2 * np.exp(self.w_1 / self.DWA_T) / exp_sum
                    self.lambda_weight[1, self.epoch] = 2 * np.exp(self.w_2 / self.DWA_T) / exp_sum
                elif K == 4:
                    self.w_1 = self.avg_cost[self.epoch - 1, 0] / self.avg_cost[self.epoch - 2, 0]
                    self.w_2 = self.avg_cost[self.epoch - 1, 1] / self.avg_cost[self.epoch - 2, 1]
                    self.w_3 = self.avg_cost[self.epoch - 1, 2] / self.avg_cost[self.epoch - 2, 2]
                    self.w_4 = self.avg_cost[self.epoch - 1, 3] / self.avg_cost[self.epoch - 2, 3]
                    exp_sum = (np.exp(self.w_1 / self.DWA_T) + np.exp(self.w_2 / self.DWA_T) +
                               np.exp(self.w_3 / self.DWA_T) + np.exp(self.w_4 / self.DWA_T))
                    self.lambda_weight[0, self.epoch] = 4 * np.exp(self.w_1 / self.DWA_T) / exp_sum
                    self.lambda_weight[1, self.epoch] = 4 * np.exp(self.w_2 / self.DWA_T) / exp_sum
                    self.lambda_weight[2, self.epoch] = 4 * np.exp(self.w_3 / self.DWA_T) / exp_sum
                    self.lambda_weight[3, self.epoch] = 4 * np.exp(self.w_4 / self.DWA_T) / exp_sum

            with train(self):
                self.run_epoch(self.train_iter_logger)
                self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()

            stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
                                                           self.epoch)
            if stop:
                print('Early Stopping Epoch : [{0}]'.format(self.epoch))
                break
            print('Epoch: {:04d} | TRAIN: {:.4f} {:.4f}'.format(
                self.epoch, self.avg_cost[self.epoch, 0], self.avg_cost[self.epoch, 1]))

    def train_epoch(self, logger):
        j = 0
        self.train_iterations = len(iter(self.train_loader))
        with tqdm(self.train_loader, desc='Epoch: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(self):
            for inputs, visualization in pbar:
                try:
                    self.absolute_iter += 1
                    inputs = self._allocate_tensor(inputs)
                    output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                    self.component_losses = self.criterion(output)
                    if K == 2:
                        [self.l1, self.l2] = self.component_losses
                        self.loss_tracker = self.l1 + self.l2
                    elif K == 4:
                        [self.l1, self.l2, self.l3, self.l4] = self.component_losses
                        self.loss_tracker = self.l1 + self.l2 + self.l3 + self.l4
                    self.loss = torch.mean(
                        sum(self.lambda_weight[i, self.epoch] * self.component_losses[i] for i in range(self.K)))
                    self.optimizer.zero_grad()
                    self.loss.backward()
                    # update the per-task costs before accumulating them into avg_cost
                    # (fixed: the original accumulated the previous iteration's costs)
                    if K == 2:
                        self.cost = [self.l1, self.l2]
                    elif K == 4:
                        self.cost = [self.l1, self.l2, self.l3, self.l4]
                    self.avg_cost[self.epoch] += torch.stack(
                        self.cost).detach().cpu().clone().numpy() / self.train_batches
                    self.gradients()
                    self.optimizer.step()
                    pbar.set_postfix(loss=self.loss)
                    self.loss_.data.print_logger(self.epoch, j, self.train_iterations, logger)
                    self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                    j += 1
                except Exception as e:
                    try:
                        self.save_checkpoint(filename=os.path.join(self.workdir, 'checkpoint_backup.pth'))
                    except Exception:
                        self.err_logger.error('Failed to deal with exception. Could not save backup at {0} \n'
                                              .format(os.path.join(self.workdir, 'checkpoint_backup.pth')))
                    self.err_logger.error(str(e))
                    raise e
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        for idx, src in enumerate(SOURCES_SUBSET):
            self.writer.add_scalars('weights', {'W_' + src: self.lambda_weight[idx, self.epoch].item()}, self.epoch)
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
        self.save_checkpoint()

    def validate_epoch(self):
        with tqdm(self.val_loader, desc='Validation: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, \
                ctx_iter(self):
            for inputs, visualization in pbar:
                self.val_iterations += 1
                self.loss_.data.update_timed()
                inputs = self._allocate_tensor(inputs)
                output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                self.component_losses = self.criterion(output)
                if K == 2:
                    [self.l1, self.l2] = self.component_losses
                    self.loss_tracker = self.l1 + self.l2
                elif K == 4:
                    [self.l1, self.l2, self.l3, self.l4] = self.component_losses
                    self.loss_tracker = self.l1 + self.l2 + self.l3 + self.l4
                self.loss = torch.mean(
                    sum(self.lambda_weight[i, self.epoch] * self.component_losses[i] for i in range(self.K)))
                self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                pbar.set_postfix(loss=self.loss)
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)

    @checkpoint_on_key
    @assert_workdir
    def save_checkpoint(self, filename=None):
        state = {'epoch': self.epoch + 1,
                 'iter': self.absolute_iter + 1,
                 'arch': self.model_version,
                 'state_dict': self.model.state_dict(),
                 'optimizer': self.optimizer.state_dict(),
                 'loss': self.loss_,
                 'key': self.key,
                 'scheduler': self.scheduler.state_dict()}
        if filename is None:
            filename = os.path.join(self.workdir, self.checkpoint_name)
        elif isinstance(filename, str):
            filename = os.path.join(self.workdir, filename)
        print('Saving checkpoint at : {}'.format(filename))
        torch.save(state, filename)
        if self.loss_tracker_.data.is_best:
            shutil.copyfile(filename, os.path.join(self.workdir, 'best' + self.checkpoint_name))
        print('Checkpoint saved successfully')

    def tensorboard_writer(self, loss, output, gt, absolute_iter, visualization):
        if self.state == 'train':
            iter_val = absolute_iter
        elif self.state == 'val':
            iter_val = self.val_iterations
        if self.iterating:
            if iter_val % PARAMETER_SAVE_FREQUENCY == 0:
                text = visualization[1]
                self.writer.add_text('Filepath', text[-1], iter_val)
                phase = visualization[0].detach().cpu().clone().numpy()
                gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, pred_masks = output
                if len(text) == BATCH_SIZE:
                    grid_unwarp = self.grid_unwarp
                else:  # for the last batch, which is generally smaller than batch_size
                    grid_unwarp = torch.from_numpy(
                        warpgrid(len(text), NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
                pred_masks_linear = linearize_log_freq_scale(pred_masks, grid_unwarp)
                gt_masks_linear = linearize_log_freq_scale(gt_masks, grid_unwarp)
                oracle_spec = (mix_mag * gt_masks_linear)
                pred_spec = (mix_mag * pred_masks_linear)
                for i, sample in enumerate(text):
                    sample_id = os.path.basename(sample)[:-4]
                    folder_name = os.path.basename(os.path.dirname(sample))
                    pred_audio_out_folder = os.path.join(self.audio_dumps_folder, folder_name, sample_id)
                    create_folder(pred_audio_out_folder)
                    visuals_out_folder = os.path.join(self.visual_dumps_folder, folder_name, sample_id)
                    create_folder(visuals_out_folder)
                    for j, source in enumerate(SOURCES_SUBSET):
                        gt_audio = torch.from_numpy(
                            istft_reconstruction(gt_mags.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                        pred_audio = torch.from_numpy(
                            istft_reconstruction(pred_spec.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                        librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'GT_' + source + '.wav'),
                                                 gt_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                        librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'PR_' + source + '.wav'),
                                                 pred_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                        ### SAVING MAG SPECTROGRAMS ###
                        save_spectrogram(gt_mags[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_GT.png')
                        save_spectrogram(oracle_spec[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_ORACLE.png')
                        save_spectrogram(pred_spec[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_ESTIMATE.png')
                ### PLOTTING MAG SPECTROGRAMS ON TENSORBOARD ###
                plot_spectrogram(self.writer, gt_mags.detach().cpu().view(-1, 1, 512, 256)[:8],
                                 self.state + '_GT_MAG', iter_val)
                plot_spectrogram(self.writer,
                                 (pred_masks_linear * mix_mag).detach().cpu().view(-1, 1, 512, 256)[:8],
                                 self.state + '_PRED_MAG', iter_val)
        else:
            if K == 2:
                self.writer.add_scalars(self.state + ' losses_epoch', {'Voice Est Loss': self.l1}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Acc Est Loss': self.l2}, self.epoch)
            elif K == 4:
                self.writer.add_scalars(self.state + ' losses_epoch', {'Voice Est Loss': self.l1}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Drums Est Loss': self.l2}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Bass Est Loss': self.l3}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Other Est Loss': self.l4}, self.epoch)
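# The K == 2 and K == 4 branches in DWA.train duplicate the same formula:
# w_k = L_k(t-1) / L_k(t-2) and lambda_k = K * exp(w_k / T) / sum_i exp(w_i / T),
# i.e. a temperature-scaled softmax over loss ratios. A hypothetical helper
# (not part of the original class) that computes the weights for any K:
import numpy as np

def dwa_weights(avg_cost, epoch, temperature):
    """Dynamic Weight Average weights for all K tasks at a given epoch.

    avg_cost: (EPOCHS, K) array of per-epoch average task losses.
    Returns a length-K array of weights that sums to K.
    """
    K = avg_cost.shape[1]
    if epoch < 2:  # no loss history yet: weight all tasks equally
        return np.ones(K)
    w = avg_cost[epoch - 1] / avg_cost[epoch - 2]  # relative descent rate per task
    e = np.exp(w / temperature)
    return K * e / e.sum()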
import io
import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch  # needed below for torch.device / torch.cuda (missing in the original)
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from pytorch_transformers import XLNetModel, XLNetTokenizer, XLNetForSequenceClassification
from pytorch_transformers import AdamW
from tqdm import tqdm, trange

from utils.EarlyStopping import EarlyStopping

es = EarlyStopping()

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print("device: {}".format(device))

df = pd.read_csv("./data/train.csv", encoding='utf-8')
sentences = df['comment_text'].values
test_df = pd.read_csv("./data/test.csv", encoding='utf-8')
test_sentences = test_df['comment_text']

# define output dataframe
sample = pd.read_csv("./data/sample_submission.csv")
for label in [
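# The imports above suggest the usual pytorch_transformers preprocessing
# pipeline. A minimal sketch of how they are typically wired together; MAX_LEN
# is a hypothetical constant, and the original snippet is truncated before this
# point, so none of the following is the author's actual code:
MAX_LEN = 128
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

# tokenize each comment and map tokens to vocabulary ids
tokenized = [tokenizer.tokenize(s) for s in sentences]
input_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokenized]

# pad/truncate every sequence to the same length
input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long",
                          truncating="post", padding="post")

# a validation split would then come from train_test_split, e.g.
# train_inputs, val_inputs = train_test_split(input_ids, random_state=42, test_size=0.1)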
optimizer, scheduler = get_optimizer_and_lr_scheduler(model, args['optimizer'], args['learning_rate'],
                                                      args['weight_decay'], steps_per_epoch=len(train_loader),
                                                      epochs=args['epochs'])

save_model_folder = f"../save_model/{args['dataset']}/{args['model_name']}"
shutil.rmtree(save_model_folder, ignore_errors=True)
os.makedirs(save_model_folder, exist_ok=True)

early_stopping = EarlyStopping(patience=args['patience'], save_model_folder=save_model_folder,
                               save_model_name=args['model_name'])

loss_func = nn.CrossEntropyLoss()

train_steps = 0
for epoch in range(args['epochs']):
    model.train()
    train_y_trues = []
    train_y_predicts = []
    train_total_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, ncols=120)
    for batch, (input_nodes, output_nodes, blocks) in enumerate(train_loader_tqdm):
optimizer, scheduler = get_optimizer_and_lr_scheduler(model, args['optimizer'], args['learning_rate'],
                                                      args['weight_decay'], steps_per_epoch=len(train_loader),
                                                      epochs=args['epochs'])

save_model_folder = f"../save_model/{args['dataset']}/{args['model_name']}"
shutil.rmtree(save_model_folder, ignore_errors=True)
os.makedirs(save_model_folder, exist_ok=True)

early_stopping = EarlyStopping(patience=args['patience'], save_model_folder=save_model_folder,
                               save_model_name=args['model_name'])

# BCELoss expects probabilities in [0, 1], so the model output is assumed to
# already pass through a sigmoid (unlike the CrossEntropyLoss variant above,
# which expects raw logits)
loss_func = nn.BCELoss()

train_steps = 0
best_validate_RMSE, final_result = None, None
for epoch in range(args['epochs']):
    model.train()
    train_y_trues = []
    train_y_predicts = []
    train_total_loss = 0.0
    train_loader_tqdm = tqdm(train_loader, ncols=120)
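# get_optimizer_and_lr_scheduler is a project helper that is not shown here.
# Since it receives steps_per_epoch and epochs, one plausible sketch pairs the
# optimizer with torch.optim.lr_scheduler.OneCycleLR; this is an assumption,
# not the project's actual implementation:
import torch
import torch.nn as nn

def get_optimizer_and_lr_scheduler(model, optimizer_name, learning_rate, weight_decay,
                                   steps_per_epoch, epochs):
    if optimizer_name == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    elif optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    else:
        raise ValueError(f"unknown optimizer {optimizer_name}")
    # one learning-rate cycle spanning the whole training run
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=learning_rate,
                                                    steps_per_epoch=steps_per_epoch, epochs=epochs)
    return optimizer, scheduler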
class Baseline(pytorchfw):

    def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
        super(Baseline, self).__init__(model, rootdir, workname, main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0

    def print_args(self):
        setup_logger('log_info', self.workdir + '/info_file.txt',
                     FORMAT="[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s]")
        logger = logging.getLogger('log_info')
        self.print_info(logger)
        logger.info(f'\r\t Spectrogram data dir: {ROOT_DIR}\r'
                    'TRAINING PARAMETERS: \r\t'
                    f'Run name: {self.workname}\r\t'
                    f'Batch size {BATCH_SIZE} \r\t'
                    f'Optimizer {OPTIMIZER} \r\t'
                    f'Initializer {INITIALIZER} \r\t'
                    f'Epochs {EPOCHS} \r\t'
                    f'LR General: {LR} \r\t'
                    f'SGD Momentum {MOMENTUM} \r\t'
                    f'Weight Decay {WEIGHT_DECAY} \r\t'
                    f'Pre-trained model: {PRETRAINED} \r'
                    'MODEL PARAMETERS \r\t'
                    f'Nº instruments (K) {K} \r\t'
                    f'U-Net activation: {ACTIVATION} \r\t'
                    f'U-Net Input channels {INPUT_CHANNELS}\r\t'
                    f'U-Net Batch normalization {USE_BN} \r\t')

    def set_optim(self, *args, **kwargs):
        if OPTIMIZER == 'adam':
            kwargs.pop('momentum', None)  # Adam does not accept a momentum kwarg
            return torch.optim.Adam(*args, **kwargs)
        elif OPTIMIZER == 'SGD':
            return torch.optim.SGD(*args, **kwargs)
        else:
            raise ValueError('Unsupported optimizer. Implement it.')

    def hyperparameters(self):
        self.dataparallel = False
        self.initializer = INITIALIZER
        self.EPOCHS = EPOCHS
        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        self.LR = LR
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=7, threshold=3e-4)

    def set_config(self):
        self.batch_size = BATCH_SIZE
        self.criterion = SingleSourceDirectLoss(self.main_device)

    @config
    @set_training
    def train(self):
        self.print_args()
        self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
        create_folder(self.audio_dumps_folder)
        self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
        create_folder(self.visual_dumps_folder)
        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)

        training_data = UnetInput('train')
        self.train_loader = torch.utils.data.DataLoader(training_data, batch_size=BATCH_SIZE,
                                                        shuffle=True, num_workers=10)
        validation_data = UnetInput('val')
        self.val_loader = torch.utils.data.DataLoader(validation_data, batch_size=BATCH_SIZE,
                                                      shuffle=True, num_workers=10)

        for self.epoch in range(self.start_epoch, self.EPOCHS):
            with train(self):
                self.run_epoch(self.train_iter_logger)
                self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()
            stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
                                                           self.epoch)
            if stop:
                print('Early Stopping Epoch : [{0}], '
                      'Best Checkpoint Epoch : [{1}]'.format(self.epoch, self.EarlyStopChecker.best_epoch))
                break

    def train_epoch(self, logger):
        j = 0
        self.train_iterations = len(iter(self.train_loader))
        with tqdm(self.train_loader, desc='Epoch: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(self):
            for inputs, visualization in pbar:
                try:
                    self.absolute_iter += 1
                    inputs = self._allocate_tensor(inputs)
                    output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                    self.loss = self.criterion(output)
                    self.optimizer.zero_grad()
                    self.loss.backward()
                    self.gradients()
                    self.optimizer.step()
                    pbar.set_postfix(loss=self.loss)
                    self.loss_.data.print_logger(self.epoch, j, self.train_iterations, logger)
                    self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                    j += 1
                except Exception as e:
                    try:
                        self.save_checkpoint(filename=os.path.join(self.workdir, 'checkpoint_backup.pth'))
                    except Exception:
                        self.err_logger.error('Failed to deal with exception. Could not save backup at {0} \n'
                                              .format(os.path.join(self.workdir, 'checkpoint_backup.pth')))
                    self.err_logger.error(str(e))
                    raise e
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
        self.save_checkpoint()

    def validate_epoch(self):
        with tqdm(self.val_loader, desc='Validation: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, \
                ctx_iter(self):
            for inputs, visualization in pbar:
                self.val_iterations += 1
                self.loss_.data.update_timed()
                inputs = self._allocate_tensor(inputs)
                output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                self.loss = self.criterion(output)
                self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                pbar.set_postfix(loss=self.loss)
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)

    def tensorboard_writer(self, loss, output, gt, absolute_iter, visualization):
        if self.state == 'train':
            iter_val = absolute_iter
        elif self.state == 'val':
            iter_val = self.val_iterations
        if iter_val % PARAMETER_SAVE_FREQUENCY == 0:
            text = visualization[1]
            self.writer.add_text('Filepath', text[-1], iter_val)
            phase = visualization[0].detach().cpu().clone().numpy()
            gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, pred_masks = output
            if len(text) == BATCH_SIZE:
                grid_unwarp = self.grid_unwarp
            else:  # for the last batch, which is generally smaller than batch_size
                grid_unwarp = torch.from_numpy(
                    warpgrid(len(text), NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
            pred_masks_linear = linearize_log_freq_scale(pred_masks, grid_unwarp)
            gt_masks_linear = linearize_log_freq_scale(gt_masks, grid_unwarp)
            oracle_spec = (mix_mag * gt_masks_linear)
            pred_spec = (mix_mag * pred_masks_linear)
            j = ISOLATED_SOURCE_ID
            source = SOURCES_SUBSET[j]
            for i, sample in enumerate(text):
                sample_id = os.path.basename(sample)[:-4]
                folder_name = os.path.basename(os.path.dirname(sample))
                pred_audio_out_folder = os.path.join(self.audio_dumps_folder, folder_name, sample_id)
                create_folder(pred_audio_out_folder)
                visuals_out_folder = os.path.join(self.visual_dumps_folder, folder_name, sample_id)
                create_folder(visuals_out_folder)
                gt_audio = torch.from_numpy(
                    istft_reconstruction(gt_mags.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                pred_audio = torch.from_numpy(
                    istft_reconstruction(pred_spec.detach().cpu().numpy()[i][0], phase[i][0], HOP_LENGTH))
                librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'GT_' + source + '.wav'),
                                         gt_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'PR_' + source + '.wav'),
                                         pred_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                ### SAVING MAG SPECTROGRAMS ###
                save_spectrogram(gt_mags[i][j].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_GT.png')
                save_spectrogram(oracle_spec[i][j].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_ORACLE.png')
                save_spectrogram(pred_spec[i][0].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_ESTIMATE.png')
            ### PLOTTING MAG SPECTROGRAMS ON TENSORBOARD ###
            plot_spectrogram(self.writer, gt_mags[:, j].detach().cpu().view(-1, 1, 512, 256)[:8],
                             self.state + '_GT_MAG', iter_val)
            plot_spectrogram(self.writer,
                             (pred_masks_linear * mix_mag).detach().cpu().view(-1, 1, 512, 256)[:8],
                             self.state + '_PRED_MAG', iter_val)
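# Baseline and DWA call EarlyStopChecker.check_improvement(val_loss, epoch) and
# read EarlyStopChecker.best_epoch, an interface different from the callable
# EarlyStopping used in the train_model functions. A minimal sketch of what
# this variant assumes (the real class may differ):
class EarlyStopping:
    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float('inf')
        self.best_epoch = 0
        self.counter = 0

    def check_improvement(self, val_loss, epoch):
        """Return True once val_loss has not improved for `patience` epochs."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience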
def train_model(model, batch_size, patience, n_epochs, gpu):
    # assumes train_loader, val_loader, optimizer, myLoss, lr, plotter_train
    # and plotter_eval are defined at module scope
    train_losses = []       # training loss per iteration
    valid_losses = []       # validation loss per iteration
    avg_train_losses = []   # average training loss per epoch
    avg_valid_losses = []   # average validation loss per epoch

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, n_epochs + 1):
        ###################
        # train the model #
        ###################
        model.train()  # prep model for training
        t = tqdm(iter(train_loader), desc="[Train on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t):
            # get the inputs
            data = example["image"]
            target = example["fixations"]
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # push data and targets to the GPU
            if gpu and torch.cuda.is_available():
                data = data.to('cuda')
                target = target.to('cuda')
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = myLoss(output, target)
            # backward pass: compute gradient of the loss w.r.t. model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # record training loss
            train_losses.append(loss.item())
            # for the first epoch, plot the loss per iteration for a quick
            # overview of the early training phase; the plot always appends the
            # newest value, so just pass the last item of the list
            iteration = i + 1
            if epoch == 1:
                plotter_train.plot('loss', 'train', 'Loss per Iteration', iteration,
                                   train_losses[-1], batch_size, lr, 'iteration')

        ######################
        # validate the model #
        ######################
        model.eval()  # prep model for evaluation
        t = tqdm(iter(val_loader), desc="[Valid on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t):
            # get the inputs
            data = example["image"]
            target = example["fixations"]
            # push data and targets to the GPU
            if gpu and torch.cuda.is_available():
                data = data.to('cuda')
                target = target.to('cuda')
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = myLoss(output, target)
            # record validation loss
            valid_losses.append(loss.item())

        # print training/validation statistics:
        # calculate the average loss over the epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        print_msg = ('[{}/{}] '.format(epoch, n_epochs) +
                     'train_loss: {:.5f} '.format(train_loss) +
                     'valid_loss: {:.5f}'.format(valid_loss))
        print(print_msg)

        # plot average losses for this epoch
        plotter_eval.plot('loss', 'train', 'Loss per Epoch', epoch, train_loss, batch_size, lr, 'epoch')
        plotter_eval.plot('loss', 'val', 'Loss per Epoch', epoch, valid_loss, batch_size, lr, 'epoch')

        # clear lists to track the next epoch
        train_losses = []
        valid_losses = []

        # early_stopping needs the validation loss to check whether it has
        # decreased; if it has, it makes a checkpoint of the current model
        early_stopping(valid_loss, model, batch_size, lr)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint with the best model
    name = "checkpoint_batch_size_{}_lr_{}.pt".format(batch_size, lr)
    model.load_state_dict(torch.load(name))

    return model, avg_train_losses, avg_valid_losses
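# Example driver for the train_model variant above (hypothetical values; it
# assumes the module-scope objects listed at the top of the function exist):
model, avg_train_losses, avg_valid_losses = train_model(
    model, batch_size=128, patience=10, n_epochs=50, gpu=True)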