Example #1
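# NOTE (assumption): this listing is an excerpt. Project-specific helpers such as
# pytorchfw, warpgrid, UnetInput, EarlyStopping, IndividualLosses, SingleSourceDirectLoss,
# setup_logger, create_folder, istft_reconstruction, linearize_log_freq_scale,
# save_spectrogram, plot_spectrogram, the train/val/ctx_iter context managers, the
# decorators and the upper-case configuration constants come from the surrounding
# repository. The standard and third-party imports below are what the code visibly uses.
import os
import shutil
import logging

import numpy as np
import torch
import librosa  # librosa < 0.8, which still provides librosa.output.write_wav
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

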
class DWA(pytorchfw):
    def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
        super(DWA, self).__init__(model, rootdir, workname, main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.audio_dumps_folder = None
        self.visual_dumps_folder = None
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)

        self.set_tensor_scalar_item('l1')
        self.set_tensor_scalar_item('l2')
        if K == 4:
            self.set_tensor_scalar_item('l3')
            self.set_tensor_scalar_item('l4')
        self.set_tensor_scalar_item('loss_tracker')
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0

    def print_args(self):
        setup_logger('log_info', self.workdir + '/info_file.txt',
                     FORMAT="[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s]")
        logger = logging.getLogger('log_info')
        self.print_info(logger)
        logger.info(f'\r\t Spectrogram data dir: {ROOT_DIR}\r'
                    'TRAINING PARAMETERS: \r\t'
                    f'Run name: {self.workname}\r\t'
                    f'Batch size {BATCH_SIZE} \r\t'
                    f'Optimizer {OPTIMIZER} \r\t'
                    f'Initializer {INITIALIZER} \r\t'
                    f'Epochs {EPOCHS} \r\t'
                    f'LR General: {LR} \r\t'
                    f'SGD Momentum {MOMENTUM} \r\t'
                    f'Weight Decay {WEIGHT_DECAY} \r\t'
                    f'Pre-trained model:  {PRETRAINED} \r'
                    'MODEL PARAMETERS \r\t'
                    f'Nº instruments (K) {K} \r\t'
                    f'U-Net activation: {ACTIVATION} \r\t'
                    f'U-Net Input channels {INPUT_CHANNELS}\r\t'
                    f'U-Net Batch normalization {USE_BN} \r\t')

    def set_optim(self, *args, **kwargs):
        if OPTIMIZER == 'adam':
            # Adam takes no momentum argument; drop it if the caller passed one
            kwargs.pop('momentum', None)
            return torch.optim.Adam(*args, **kwargs)
        elif OPTIMIZER == 'SGD':
            return torch.optim.SGD(*args, **kwargs)
        else:
            raise ValueError(f'Unsupported optimizer: {OPTIMIZER}. Implement it.')

    def hyperparameters(self):
        self.dataparallel = False
        self.initializer = INITIALIZER
        self.EPOCHS = EPOCHS
        self.DWA_T = DWA_TEMP
        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        self.LR = LR
        self.K = K
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=7, threshold=3e-4)

    def set_config(self):
        self.batch_size = BATCH_SIZE
        self.criterion = IndividualLosses(self.main_device)

    @config
    @set_training
    def train(self):

        self.print_args()
        self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
        create_folder(self.audio_dumps_folder)
        self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
        create_folder(self.visual_dumps_folder)

        training_data = UnetInput('train')
        self.train_loader = torch.utils.data.DataLoader(training_data,
                                                        batch_size=BATCH_SIZE,
                                                        shuffle=True,
                                                        num_workers=10)
        self.train_batches = len(self.train_loader)
        validation_data = UnetInput('val')
        self.val_loader = torch.utils.data.DataLoader(validation_data,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=True,
                                                      num_workers=10)
        self.val_batches = len(self.val_loader)

        self.avg_cost = np.zeros([self.EPOCHS, self.K], dtype=np.float32)
        self.lambda_weight = np.ones([len(SOURCES_SUBSET), self.EPOCHS])
        for self.epoch in range(self.start_epoch, self.EPOCHS):
            self.cost = list(torch.zeros(self.K))

            # apply Dynamic Weight Average
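            # Dynamic Weight Average (Liu et al., 2019): w_k = L_k(t-1) / L_k(t-2) is the
            # relative descent rate of task k over the two previous epochs, and
            # lambda_k(t) = K * exp(w_k / T) / sum_i exp(w_i / T), so the weights sum to K.
            # The first two epochs have no loss history, so uniform weights are used.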
            if self.epoch == 0 or self.epoch == 1:
                self.lambda_weight[:, self.epoch] = 1.0
            else:
                if K == 2:
                    self.w_1 = self.avg_cost[self.epoch - 1, 0] / self.avg_cost[self.epoch - 2, 0]
                    self.w_2 = self.avg_cost[self.epoch - 1, 1] / self.avg_cost[self.epoch - 2, 1]
                    exp_sum = (np.exp(self.w_1 / self.DWA_T) + np.exp(self.w_2 / self.DWA_T))
                    self.lambda_weight[0, self.epoch] = 2 * np.exp(self.w_1 / self.DWA_T) / exp_sum
                    self.lambda_weight[1, self.epoch] = 2 * np.exp(self.w_2 / self.DWA_T) / exp_sum
                elif K == 4:
                    self.w_1 = self.avg_cost[self.epoch - 1, 0] / self.avg_cost[self.epoch - 2, 0]
                    self.w_2 = self.avg_cost[self.epoch - 1, 1] / self.avg_cost[self.epoch - 2, 1]
                    self.w_3 = self.avg_cost[self.epoch - 1, 2] / self.avg_cost[self.epoch - 2, 2]
                    self.w_4 = self.avg_cost[self.epoch - 1, 3] / self.avg_cost[self.epoch - 2, 3]
                    exp_sum = np.exp(self.w_1 / self.DWA_T) + np.exp(self.w_2 / self.DWA_T) + np.exp(
                        self.w_3 / self.DWA_T) + np.exp(self.w_4 / self.DWA_T)
                    self.lambda_weight[0, self.epoch] = 4 * np.exp(self.w_1 / self.DWA_T) / exp_sum
                    self.lambda_weight[1, self.epoch] = 4 * np.exp(self.w_2 / self.DWA_T) / exp_sum
                    self.lambda_weight[2, self.epoch] = 4 * np.exp(self.w_3 / self.DWA_T) / exp_sum
                    self.lambda_weight[3, self.epoch] = 4 * np.exp(self.w_4 / self.DWA_T) / exp_sum

            with train(self):
                self.run_epoch(self.train_iter_logger)
            self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()

            stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
                                                           self.epoch)
            if stop:
                print('Early Stopping Epoch : [{0}]'.format(self.epoch))
                break
            print(
                'Epoch: {:04d} | TRAIN: {:.4f} {:.4f}'.format(self.epoch,
                                                              self.avg_cost[self.epoch, 0],
                                                              self.avg_cost[self.epoch, 1]))

    def train_epoch(self, logger):
        j = 0
        self.train_iterations = len(self.train_loader)
        with tqdm(self.train_loader, desc='Epoch: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(self):
            for inputs, visualization in pbar:
                try:
                    self.absolute_iter += 1
                    inputs = self._allocate_tensor(inputs)
                    output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                    self.component_losses = self.criterion(output)
                    if K == 2:
                        [self.l1, self.l2] = self.component_losses
                        self.loss_tracker = self.l1 + self.l2
                    elif K == 4:
                        [self.l1, self.l2, self.l3, self.l4] = self.component_losses
                        self.loss_tracker = self.l1 + self.l2 + self.l3 + self.l4
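                    # Total loss: mean of the DWA-weighted sum of the per-source losses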
                    self.loss = torch.mean(
                        sum(self.lambda_weight[i, self.epoch] * self.component_losses[i] for i in range(self.K)))
                    self.optimizer.zero_grad()
                    self.loss.backward()
                    self.avg_cost[self.epoch] += torch.stack(
                        self.cost).detach().cpu().clone().numpy() / self.train_batches
                    self.gradients()
                    self.optimizer.step()
                    if K == 2:
                        self.cost = [self.l1, self.l2]
                    elif K == 4:
                        self.cost = [self.l1, self.l2, self.l3, self.l4]
                    pbar.set_postfix(loss=self.loss)
                    self.loss_.data.print_logger(self.epoch, j, self.train_iterations, logger)
                    self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                    j += 1

                except Exception as e:
                    try:
                        self.save_checkpoint(filename=os.path.join(self.workdir, 'checkpoint_backup.pth'))
                    except Exception:
                        self.err_logger.error('Failed to handle exception. Could not save backup at {0}\n'
                                              .format(os.path.join(self.workdir, 'checkpoint_backup.pth')))
                    self.err_logger.error(str(e))
                    raise e

        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        for idx, src in enumerate(SOURCES_SUBSET):
            self.writer.add_scalars('weights', {'W_' + src: self.lambda_weight[idx, self.epoch].item()}, self.epoch)
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
        self.save_checkpoint()

    def validate_epoch(self):
        with tqdm(self.val_loader, desc='Validation: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(
                self):
            for inputs, visualization in pbar:
                self.val_iterations += 1
                self.loss_.data.update_timed()
                inputs = self._allocate_tensor(inputs)
                output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                self.component_losses = self.criterion(output)
                if K == 2:
                    [self.l1, self.l2] = self.component_losses
                    self.loss_tracker = self.l1 + self.l2
                elif K == 4:
                    [self.l1, self.l2, self.l3, self.l4] = self.component_losses
                    self.loss_tracker = self.l1 + self.l2 + self.l3 + self.l4
                self.loss = torch.mean(
                    sum(self.lambda_weight[i, self.epoch] * self.component_losses[i] for i in range(self.K)))
                self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                pbar.set_postfix(loss=self.loss)
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)

    @checkpoint_on_key
    @assert_workdir
    def save_checkpoint(self, filename=None):
        state = {
            'epoch': self.epoch + 1,
            'iter': self.absolute_iter + 1,
            'arch': self.model_version,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'loss': self.loss_,
            'key': self.key,
            'scheduler': self.scheduler.state_dict()
        }
        if filename is None:
            filename = os.path.join(self.workdir, self.checkpoint_name)

        elif isinstance(filename, str):
            filename = os.path.join(self.workdir, filename)
        print('Saving checkpoint at : {}'.format(filename))
        torch.save(state, filename)
        if self.loss_tracker_.data.is_best:
            shutil.copyfile(filename, os.path.join(self.workdir, 'best' + self.checkpoint_name))
        print('Checkpoint saved successfully')

    def tensorboard_writer(self, loss, output, gt, absolute_iter, visualization):
        if self.state == 'train':
            iter_val = absolute_iter
        elif self.state == 'val':
            iter_val = self.val_iterations

        if self.iterating:
            if iter_val % PARAMETER_SAVE_FREQUENCY == 0:
                text = visualization[1]
                self.writer.add_text('Filepath', text[-1], iter_val)
                phase = visualization[0].detach().cpu().clone().numpy()
                gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, pred_masks = output
                if len(text) == BATCH_SIZE:
                    grid_unwarp = self.grid_unwarp
                else:  # the last batch generally has fewer samples than the batch size
                    grid_unwarp = torch.from_numpy(
                        warpgrid(len(text), NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
                pred_masks_linear = linearize_log_freq_scale(pred_masks, grid_unwarp)
                gt_masks_linear = linearize_log_freq_scale(gt_masks, grid_unwarp)
                oracle_spec = (mix_mag * gt_masks_linear)
                pred_spec = (mix_mag * pred_masks_linear)

                for i, sample in enumerate(text):
                    sample_id = os.path.basename(sample)[:-4]
                    folder_name = os.path.basename(os.path.dirname(sample))
                    pred_audio_out_folder = os.path.join(self.audio_dumps_folder, folder_name, sample_id)
                    create_folder(pred_audio_out_folder)
                    visuals_out_folder = os.path.join(self.visual_dumps_folder, folder_name, sample_id)
                    create_folder(visuals_out_folder)

                    for j, source in enumerate(SOURCES_SUBSET):
                        gt_audio = torch.from_numpy(
                            istft_reconstruction(gt_mags.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                        pred_audio = torch.from_numpy(
                            istft_reconstruction(pred_spec.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                        librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'GT_' + source + '.wav'),
                                                 gt_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                        librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'PR_' + source + '.wav'),
                                                 pred_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)

                        ### SAVING MAG SPECTROGRAMS ###
                        save_spectrogram(gt_mags[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_GT.png')
                        save_spectrogram(oracle_spec[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_ORACLE.png')
                        save_spectrogram(pred_spec[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_ESTIMATE.png')

                ### PLOTTING MAG SPECTROGRAMS ON TENSORBOARD ###
                plot_spectrogram(self.writer, gt_mags.detach().cpu().view(-1, 1, 512, 256)[:8],
                                 self.state + '_GT_MAG', iter_val)
                plot_spectrogram(self.writer, (pred_masks_linear * mix_mag).detach().cpu().view(-1, 1, 512, 256)[:8],
                                 self.state + '_PRED_MAG', iter_val)

        else:
            if K == 2:
                self.writer.add_scalars(self.state + ' losses_epoch', {'Voice Est Loss': self.l1}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Acc Est Loss': self.l2}, self.epoch)
            elif K == 4:
                self.writer.add_scalars(self.state + ' losses_epoch', {'Voice Est Loss': self.l1}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Drums Est Loss': self.l2}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Bass Est Loss': self.l3}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Other Est Loss': self.l4}, self.epoch)
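

# Baseline: single-source counterpart of the DWA trainer above. It reuses the same
# training loop, logging and checkpointing, but optimises a single U-Net output with
# SingleSourceDirectLoss for the source selected by ISOLATED_SOURCE_ID, without DWA
# loss weighting.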
class Baseline(pytorchfw):
    def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
        super(Baseline, self).__init__(model, rootdir, workname, main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0

    def print_args(self):
        setup_logger('log_info', self.workdir + '/info_file.txt',
                     FORMAT="[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s]")
        logger = logging.getLogger('log_info')
        self.print_info(logger)
        logger.info(f'\r\t Spectrogram data dir: {ROOT_DIR}\r'
                    'TRAINING PARAMETERS: \r\t'
                    f'Run name: {self.workname}\r\t'
                    f'Batch size {BATCH_SIZE} \r\t'
                    f'Optimizer {OPTIMIZER} \r\t'
                    f'Initializer {INITIALIZER} \r\t'
                    f'Epochs {EPOCHS} \r\t'
                    f'LR General: {LR} \r\t'
                    f'SGD Momentum {MOMENTUM} \r\t'
                    f'Weight Decay {WEIGHT_DECAY} \r\t'
                    f'Pre-trained model:  {PRETRAINED} \r'
                    'MODEL PARAMETERS \r\t'
                    f'Nº instruments (K) {K} \r\t'
                    f'U-Net activation: {ACTIVATION} \r\t'
                    f'U-Net Input channels {INPUT_CHANNELS}\r\t'
                    f'U-Net Batch normalization {USE_BN} \r\t')

    def set_optim(self, *args, **kwargs):
        if OPTIMIZER == 'adam':
            # Adam takes no momentum argument; drop it if the caller passed one
            kwargs.pop('momentum', None)
            return torch.optim.Adam(*args, **kwargs)
        elif OPTIMIZER == 'SGD':
            return torch.optim.SGD(*args, **kwargs)
        else:
            raise ValueError(f'Unsupported optimizer: {OPTIMIZER}. Implement it.')

    def hyperparameters(self):
        self.dataparallel = False
        self.initializer = INITIALIZER
        self.EPOCHS = EPOCHS
        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        self.LR = LR
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=7, threshold=3e-4)

    def set_config(self):
        self.batch_size = BATCH_SIZE
        self.criterion = SingleSourceDirectLoss(self.main_device)

    @config
    @set_training
    def train(self):

        self.print_args()
        self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
        create_folder(self.audio_dumps_folder)
        self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
        create_folder(self.visual_dumps_folder)

        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        training_data = UnetInput('train')
        self.train_loader = torch.utils.data.DataLoader(training_data,
                                                        batch_size=BATCH_SIZE,
                                                        shuffle=True,
                                                        num_workers=10)

        validation_data = UnetInput('val')
        self.val_loader = torch.utils.data.DataLoader(validation_data,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=True,
                                                      num_workers=10)
        for self.epoch in range(self.start_epoch, self.EPOCHS):
            with train(self):
                self.run_epoch(self.train_iter_logger)
            self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()
            stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
                                                           self.epoch)
            if stop:
                print('Early Stopping Epoch : [{0}], '
                      'Best Checkpoint Epoch : [{1}]'.format(self.epoch,
                                                             self.EarlyStopChecker.best_epoch))
                break

    def train_epoch(self, logger):
        j = 0
        self.train_iterations = len(self.train_loader)
        with tqdm(self.train_loader, desc='Epoch: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(self):
            for inputs, visualization in pbar:
                try:
                    self.absolute_iter += 1
                    inputs = self._allocate_tensor(inputs)
                    output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                    self.loss = self.criterion(output)
                    self.optimizer.zero_grad()
                    self.loss.backward()
                    self.gradients()
                    self.optimizer.step()
                    pbar.set_postfix(loss=self.loss)
                    self.loss_.data.print_logger(self.epoch, j, self.train_iterations, logger)
                    self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                    j += 1
                except Exception as e:
                    try:
                        self.save_checkpoint(filename=os.path.join(self.workdir, 'checkpoint_backup.pth'))
                    except Exception:
                        self.err_logger.error('Failed to handle exception. Could not save backup at {0}\n'
                                              .format(os.path.join(self.workdir, 'checkpoint_backup.pth')))
                    self.err_logger.error(str(e))
                    raise e
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
        self.save_checkpoint()

    def validate_epoch(self):
        with tqdm(self.val_loader, desc='Validation: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(
                self):
            for inputs, visualization in pbar:
                self.val_iterations += 1
                self.loss_.data.update_timed()
                inputs = self._allocate_tensor(inputs)
                output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                self.loss = self.criterion(output)
                self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                pbar.set_postfix(loss=self.loss)
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)

    def tensorboard_writer(self, loss, output, gt, absolute_iter, visualization):
        if self.state == 'train':
            iter_val = absolute_iter
        elif self.state == 'val':
            iter_val = self.val_iterations

        if iter_val % PARAMETER_SAVE_FREQUENCY == 0:
            text = visualization[1]
            self.writer.add_text('Filepath', text[-1], iter_val)
            phase = visualization[0].detach().cpu().clone().numpy()
            gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, pred_masks = output
            if len(text) == BATCH_SIZE:
                grid_unwarp = self.grid_unwarp
            else:  # the last batch generally has fewer samples than the batch size
                grid_unwarp = torch.from_numpy(
                    warpgrid(len(text), NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
            pred_masks_linear = linearize_log_freq_scale(pred_masks, grid_unwarp)
            gt_masks_linear = linearize_log_freq_scale(gt_masks, grid_unwarp)
            oracle_spec = (mix_mag * gt_masks_linear)
            pred_spec = (mix_mag * pred_masks_linear)
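            # ISOLATED_SOURCE_ID selects the single SOURCES_SUBSET entry this baseline separates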
            j = ISOLATED_SOURCE_ID
            source = SOURCES_SUBSET[j]

            for i, sample in enumerate(text):
                sample_id = os.path.basename(sample)[:-4]
                folder_name = os.path.basename(os.path.dirname(sample))
                pred_audio_out_folder = os.path.join(self.audio_dumps_folder, folder_name, sample_id)
                create_folder(pred_audio_out_folder)
                visuals_out_folder = os.path.join(self.visual_dumps_folder, folder_name, sample_id)
                create_folder(visuals_out_folder)

                gt_audio = torch.from_numpy(
                    istft_reconstruction(gt_mags.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                pred_audio = torch.from_numpy(
                    istft_reconstruction(pred_spec.detach().cpu().numpy()[i][0], phase[i][0], HOP_LENGTH))
                librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'GT_' + source + '.wav'),
                                         gt_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'PR_' + source + '.wav'),
                                         pred_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)

                ### SAVING MAG SPECTROGRAMS ###
                save_spectrogram(gt_mags[i][j].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_GT.png')
                save_spectrogram(oracle_spec[i][j].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_ORACLE.png')
                save_spectrogram(pred_spec[i][0].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_ESTIMATE.png')

            ### PLOTTING MAG SPECTROGRAMS ON TENSORBOARD ###
            plot_spectrogram(self.writer, gt_mags[:, j].detach().cpu().view(-1, 1, 512, 256)[:8],
                             self.state + '_GT_MAG', iter_val)
            plot_spectrogram(self.writer, (pred_masks_linear * mix_mag).detach().cpu().view(-1, 1, 512, 256)[:8],
                             self.state + '_PRED_MAG', iter_val)
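

# Usage sketch (assumption, not part of the original example). `build_unet()` and the
# rootdir path are hypothetical placeholders; the call pattern simply follows the
# framework conventions visible above (construct the trainer, then call train()).
if __name__ == '__main__':
    model = build_unet()  # hypothetical helper returning the U-Net to be trained
    trainer = DWA(model, rootdir='/path/to/experiments', workname='dwa_example',
                  main_device=0, trackgrad=False)
    trainer.train()
    # A single-source run would use Baseline(...) with the same arguments instead.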