    def __init__(self,
                 model,
                 rootdir,
                 workname,
                 main_device=0,
                 trackgrad=False):
        super(SpecChannelUnetNoMask, self).__init__(model, rootdir, workname,
                                                    main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.audio_dumps_folder = None
        self.visual_dumps_folder = None
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH,
                     warp=False)).to(self.main_device)

        self.set_tensor_scalar_item('l1')
        self.set_tensor_scalar_item('l2')
        if K == 4:
            self.set_tensor_scalar_item('l3')
            self.set_tensor_scalar_item('l4')
        self.set_tensor_scalar_item('loss_tracker')
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0
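
The trainer classes in these examples delegate stopping decisions to an EarlyStopping helper constructed with a patience value and queried via check_improvement(value, epoch) and best_epoch (the train_model examples further down use a different, callable variant). Its implementation is imported from elsewhere, so the following is only a minimal sketch of the interface these classes appear to assume, not the project's actual code:

class EarlyStopping:
    # Minimal sketch of the assumed interface (hypothetical, not the project's class).

    def __init__(self, patience=10):
        self.patience = patience
        self.best_value = None
        self.best_epoch = 0
        self.counter = 0

    def check_improvement(self, value, epoch):
        # Return True once `value` has failed to improve for `patience` consecutive epochs.
        if self.best_value is None or value < self.best_value:
            self.best_value, self.best_epoch, self.counter = value, epoch, 0
            return False
        self.counter += 1
        return self.counter >= self.patience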
Example #2
class Baseline(pytorchfw):
    def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
        super(Baseline, self).__init__(model, rootdir, workname, main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0
Example #3

def train_model(model, batch_size, patience, n_epochs, gpu, plotter_train,
                plotter_eval):

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(1, n_epochs + 1):

        ###################
        # train the model #
        ###################

        #schedule LR
        #scheduler.step()

        model.train()  # prep model for training
        t = tqdm(iter(train_loader),
                 desc="[Train on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t):  #start at index 0
            # get the inputs
            data = example["image"]
            #print("data size: {}".format(data.size()))
            target = example["fixations"]

            #print("target size: {}".format(target.size()))
            # clear the gradients of all optimized variables
            #optimizer_adam.zero_grad()
            #optimizer_sgd.zero_grad()
            optimizer.zero_grad()

            #push data and targets to gpu
            if gpu:
                if torch.cuda.is_available():
                    data = data.to('cuda')
                    target = target.to('cuda')

            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)

            #drop channel-dimension (it is 1) so that outputs are the same size as targets (batch_size,100,100)
            #infer batch dimension as the last batch won't have the full size of e.g. 128
            output = output.view(-1, target.size()[-2], target.size()[-1])

            #print("output size: {}".format(output.size()))
            # calculate the loss
            #loss = myLoss(output, target)
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            #perform a single optimization step (parameter update)
            #On the first 200 iterations (1st epoch), only update Adam parameters (non center bias).
            #Afterwards, alternate between the two optimizers: adam on even iterations, sgd on
            #odd iterations
            #if (epoch == 1) & (i < 200):
            #    optimizer_adam.step()
            #elif i % 2 == 0:
            #    optimizer_adam.step()
            #else:
            #    optimizer_sgd.step()
            optimizer.step()
            # record training loss
            train_losses.append(loss.item())
            #print("On iteration {} loss is {:.3f}".format(i+1, loss.item()))
            #for the first epoch, plot loss per iteration to have a quick overview of the early training phase
            iteration = i + 1
            #plot is always appending the newest value, so just give the last item of the list
            if epoch == 1:
                plotter_train.plot('loss', 'train', 'Loss per Iteration',
                                   iteration, train_losses[-1], batch_size, lr,
                                   'iteration')

        ######################
        # validate the model #
        ######################
        model.eval()  # prep model for evaluation
        t = tqdm(iter(val_loader),
                 desc="[Valid on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t):
            # get the inputs
            data = example["image"]
            #print("input sum: {}".format(torch.sum(data)))
            target = example["fixations"]

            #push data and targets to gpu
            if gpu:
                if torch.cuda.is_available():
                    data = data.to('cuda')
                    target = target.to('cuda')

            #print("target sum: {}".format(torch.sum(target)))
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)

            #drop channel-dimension (it is 1) so that outputs are the same size as targets (batch_size,100,100)
            output = output.view(-1, target.size()[-2], target.size()[-1])

            #print("output sum: {}".format(torch.sum(output)))
            # calculate the loss
            #loss = myLoss(output, target)
            loss = criterion(output, target)
            # record validation loss
            valid_losses.append(loss.item())
            #plotter_val.plot('loss', 'val', 'Loss per Iteration', iteration, valid_losses[-1])

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        epoch_len = str(n_epochs)

        print_msg = ('[{}/{}] '.format(epoch, epoch_len) +
                     'train_loss: {:.5f} '.format(train_loss) +
                     'valid_loss: {:.5f}'.format(valid_loss))

        print(print_msg)

        #plot average loss for this epoch
        plotter_eval.plot('loss', 'train', 'Loss per Epoch', epoch, train_loss,
                          batch_size, lr, 'epoch')
        plotter_eval.plot('loss', 'val', 'Loss per Epoch', epoch, valid_loss,
                          batch_size, lr, 'epoch')

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, model, batch_size, lr)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        #after every epoch, show the center bias parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                if "center_bias" in name:
                    print(name, param, param.grad)

    # load the last checkpoint with the best model
    name = "checkpoint_batch_size_{}_lr_{}.pt".format(batch_size, lr)
    model.load_state_dict(torch.load(name))

    return model, avg_train_losses, avg_valid_losses
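
train_model reads several names from module scope (train_loader, val_loader, criterion, optimizer, lr) in addition to its arguments. A hedged sketch of how it might be wired up; the network, datasets and plotter objects below are placeholders, not the ones used in the original project:

# Hypothetical wiring for train_model; FixationNet, train_set, val_set and the plotters are placeholders.
model = FixationNet()                       # any nn.Module producing (batch, 1, H, W) fixation maps
criterion = torch.nn.MSELoss()              # stand-in for the project's loss function
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128)

model, train_hist, valid_hist = train_model(model, batch_size=128, patience=5,
                                            n_epochs=30, gpu=True,
                                            plotter_train=plotter_train,
                                            plotter_eval=plotter_eval)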
Example #4
class DWA(pytorchfw):
    def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
        super(DWA, self).__init__(model, rootdir, workname, main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.audio_dumps_folder = None
        self.visual_dumps_folder = None
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)

        self.set_tensor_scalar_item('l1')
        self.set_tensor_scalar_item('l2')
        if K == 4:
            self.set_tensor_scalar_item('l3')
            self.set_tensor_scalar_item('l4')
        self.set_tensor_scalar_item('loss_tracker')
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0

    def print_args(self):
        setup_logger('log_info', self.workdir + '/info_file.txt',
                     FORMAT="[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s")
        logger = logging.getLogger('log_info')
        self.print_info(logger)
        logger.info(f'\r\t Spectrogram data dir: {ROOT_DIR}\r'
                    'TRAINING PARAMETERS: \r\t'
                    f'Run name: {self.workname}\r\t'
                    f'Batch size {BATCH_SIZE} \r\t'
                    f'Optimizer {OPTIMIZER} \r\t'
                    f'Initializer {INITIALIZER} \r\t'
                    f'Epochs {EPOCHS} \r\t'
                    f'LR General: {LR} \r\t'
                    f'SGD Momentum {MOMENTUM} \r\t'
                    f'Weight Decay {WEIGHT_DECAY} \r\t'
                    f'Pre-trained model:  {PRETRAINED} \r'
                    'MODEL PARAMETERS \r\t'
                    f'Nº instruments (K) {K} \r\t'
                    f'U-Net activation: {ACTIVATION} \r\t'
                    f'U-Net Input channels {INPUT_CHANNELS}\r\t'
                    f'U-Net Batch normalization {USE_BN} \r\t')

    def set_optim(self, *args, **kwargs):
        if OPTIMIZER == 'adam':
            return torch.optim.Adam(*args, **kwargs)
        elif OPTIMIZER == 'SGD':
            return torch.optim.SGD(*args, **kwargs)
        else:
            raise Exception('Unsupported optimizer; implement it.')

    def hyperparameters(self):
        self.dataparallel = False
        self.initializer = INITIALIZER
        self.EPOCHS = EPOCHS
        self.DWA_T = DWA_TEMP
        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        self.LR = LR
        self.K = K
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=7, threshold=3e-4)

    def set_config(self):
        self.batch_size = BATCH_SIZE
        self.criterion = IndividualLosses(self.main_device)

    @config
    @set_training
    def train(self):

        self.print_args()
        self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
        create_folder(self.audio_dumps_folder)
        self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
        create_folder(self.visual_dumps_folder)

        training_data = UnetInput('train')
        self.train_loader = torch.utils.data.DataLoader(training_data,
                                                        batch_size=BATCH_SIZE,
                                                        shuffle=True,
                                                        num_workers=10)
        self.train_batches = len(self.train_loader)
        validation_data = UnetInput('val')
        self.val_loader = torch.utils.data.DataLoader(validation_data,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=True,
                                                      num_workers=10)
        self.val_batches = len(self.val_loader)

        self.avg_cost = np.zeros([self.EPOCHS, self.K], dtype=np.float32)
        self.lambda_weight = np.ones([len(SOURCES_SUBSET), self.EPOCHS])
        for self.epoch in range(self.start_epoch, self.EPOCHS):
            self.cost = list(torch.zeros(self.K))

            # apply Dynamic Weight Average
            if self.epoch == 0 or self.epoch == 1:
                self.lambda_weight[:, self.epoch] = 1.0
            else:
                if K == 2:
                    self.w_1 = self.avg_cost[self.epoch - 1, 0] / self.avg_cost[self.epoch - 2, 0]
                    self.w_2 = self.avg_cost[self.epoch - 1, 1] / self.avg_cost[self.epoch - 2, 1]
                    exp_sum = (np.exp(self.w_1 / self.DWA_T) + np.exp(self.w_2 / self.DWA_T))
                    self.lambda_weight[0, self.epoch] = 2 * np.exp(self.w_1 / self.DWA_T) / exp_sum
                    self.lambda_weight[1, self.epoch] = 2 * np.exp(self.w_2 / self.DWA_T) / exp_sum
                elif K == 4:
                    self.w_1 = self.avg_cost[self.epoch - 1, 0] / self.avg_cost[self.epoch - 2, 0]
                    self.w_2 = self.avg_cost[self.epoch - 1, 1] / self.avg_cost[self.epoch - 2, 1]
                    self.w_3 = self.avg_cost[self.epoch - 1, 2] / self.avg_cost[self.epoch - 2, 2]
                    self.w_4 = self.avg_cost[self.epoch - 1, 3] / self.avg_cost[self.epoch - 2, 3]
                    exp_sum = np.exp(self.w_1 / self.DWA_T) + np.exp(self.w_2 / self.DWA_T) + np.exp(
                        self.w_3 / self.DWA_T) + np.exp(self.w_4 / self.DWA_T)
                    self.lambda_weight[0, self.epoch] = 4 * np.exp(self.w_1 / self.DWA_T) / exp_sum
                    self.lambda_weight[1, self.epoch] = 4 * np.exp(self.w_2 / self.DWA_T) / exp_sum
                    self.lambda_weight[2, self.epoch] = 4 * np.exp(self.w_3 / self.DWA_T) / exp_sum
                    self.lambda_weight[3, self.epoch] = 4 * np.exp(self.w_4 / self.DWA_T) / exp_sum

            with train(self):
                self.run_epoch(self.train_iter_logger)
            self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()

            stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
                                                           self.epoch)
            if stop:
                print('Early Stopping Epoch : [{0}]'.format(self.epoch))
                break
            print(
                'Epoch: {:04d} | TRAIN: {:.4f} {:.4f}'.format(self.epoch,
                                                              self.avg_cost[self.epoch, 0],
                                                              self.avg_cost[self.epoch, 1]))

    def train_epoch(self, logger):
        j = 0
        self.train_iterations = len(iter(self.train_loader))
        with tqdm(self.train_loader, desc='Epoch: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(self):
            for inputs, visualization in pbar:
                try:
                    self.absolute_iter += 1
                    inputs = self._allocate_tensor(inputs)
                    output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                    self.component_losses = self.criterion(output)
                    if K == 2:
                        [self.l1, self.l2] = self.component_losses
                        self.loss_tracker = self.l1 + self.l2
                    elif K == 4:
                        [self.l1, self.l2, self.l3, self.l4] = self.component_losses
                        self.loss_tracker = self.l1 + self.l2 + self.l3 + self.l4
                    self.loss = torch.mean(
                        sum(self.lambda_weight[i, self.epoch] * self.component_losses[i] for i in range(self.K)))
                    self.optimizer.zero_grad()
                    self.loss.backward()
                    self.avg_cost[self.epoch] += torch.stack(
                        self.cost).detach().cpu().clone().numpy() / self.train_batches
                    self.gradients()
                    self.optimizer.step()
                    if K == 2:
                        self.cost = [self.l1, self.l2]
                    elif K == 4:
                        self.cost = [self.l1, self.l2, self.l3, self.l4]
                    pbar.set_postfix(loss=self.loss)
                    self.loss_.data.print_logger(self.epoch, j, self.train_iterations, logger)
                    self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                    j += 1

                except Exception as e:
                    try:
                        self.save_checkpoint(filename=os.path.join(self.workdir, 'checkpoint_backup.pth'))
                    except Exception:
                        self.err_logger.error('Failed to deal with exception. Could not save backup at {0} \n'
                                              .format(os.path.join(self.workdir, 'checkpoint_backup.pth')))
                    self.err_logger.error(str(e))
                    raise e

        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        for idx, src in enumerate(SOURCES_SUBSET):
            self.writer.add_scalars('weights', {'W_' + src: self.lambda_weight[idx, self.epoch].item()}, self.epoch)
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
        self.save_checkpoint()

    def validate_epoch(self):
        with tqdm(self.val_loader, desc='Validation: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(
                self):
            for inputs, visualization in pbar:
                self.val_iterations += 1
                self.loss_.data.update_timed()
                inputs = self._allocate_tensor(inputs)
                output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                self.component_losses = self.criterion(output)
                if K == 2:
                    [self.l1, self.l2] = self.component_losses
                    self.loss_tracker = self.l1 + self.l2
                elif K == 4:
                    [self.l1, self.l2, self.l3, self.l4] = self.component_losses
                    self.loss_tracker = self.l1 + self.l2 + self.l3 + self.l4
                self.loss = torch.mean(
                    sum(self.lambda_weight[i, self.epoch] * self.component_losses[i] for i in range(self.K)))
                self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                pbar.set_postfix(loss=self.loss)
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)

    @checkpoint_on_key
    @assert_workdir
    def save_checkpoint(self, filename=None):
        state = {
            'epoch': self.epoch + 1,
            'iter': self.absolute_iter + 1,
            'arch': self.model_version,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'loss': self.loss_,
            'key': self.key,
            'scheduler': self.scheduler.state_dict()
        }
        if filename is None:
            filename = os.path.join(self.workdir, self.checkpoint_name)

        elif isinstance(filename, str):
            filename = os.path.join(self.workdir, filename)
        print('Saving checkpoint at : {}'.format(filename))
        torch.save(state, filename)
        if self.loss_tracker_.data.is_best:
            shutil.copyfile(filename, os.path.join(self.workdir, 'best' + self.checkpoint_name))
        print('Checkpoint saved successfully')

    def tensorboard_writer(self, loss, output, gt, absolute_iter, visualization):
        if self.state == 'train':
            iter_val = absolute_iter
        elif self.state == 'val':
            iter_val = self.val_iterations

        if self.iterating:
            if iter_val % PARAMETER_SAVE_FREQUENCY == 0:
                text = visualization[1]
                self.writer.add_text('Filepath', text[-1], iter_val)
                phase = visualization[0].detach().cpu().clone().numpy()
                gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, pred_masks = output
                if len(text) == BATCH_SIZE:
                    grid_unwarp = self.grid_unwarp
                else:  # for the last batch, where the number of samples is generally smaller than the batch size
                    grid_unwarp = torch.from_numpy(
                        warpgrid(len(text), NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
                pred_masks_linear = linearize_log_freq_scale(pred_masks, grid_unwarp)
                gt_masks_linear = linearize_log_freq_scale(gt_masks, grid_unwarp)
                oracle_spec = (mix_mag * gt_masks_linear)
                pred_spec = (mix_mag * pred_masks_linear)

                for i, sample in enumerate(text):
                    sample_id = os.path.basename(sample)[:-4]
                    folder_name = os.path.basename(os.path.dirname(sample))
                    pred_audio_out_folder = os.path.join(self.audio_dumps_folder, folder_name, sample_id)
                    create_folder(pred_audio_out_folder)
                    visuals_out_folder = os.path.join(self.visual_dumps_folder, folder_name, sample_id)
                    create_folder(visuals_out_folder)

                    for j, source in enumerate(SOURCES_SUBSET):
                        gt_audio = torch.from_numpy(
                            istft_reconstruction(gt_mags.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                        pred_audio = torch.from_numpy(
                            istft_reconstruction(pred_spec.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                        librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'GT_' + source + '.wav'),
                                                 gt_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                        librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'PR_' + source + '.wav'),
                                                 pred_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)

                        ### SAVING MAG SPECTROGRAMS ###
                        save_spectrogram(gt_mags[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_GT.png')
                        save_spectrogram(oracle_spec[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_ORACLE.png')
                        save_spectrogram(pred_spec[i][j].unsqueeze(0).detach().cpu(),
                                         os.path.join(visuals_out_folder, source), '_MAG_ESTIMATE.png')

                ### PLOTTING MAG SPECTROGRAMS ON TENSORBOARD ###
                plot_spectrogram(self.writer, gt_mags.detach().cpu().view(-1, 1, 512, 256)[:8],
                                 self.state + '_GT_MAG', iter_val)
                plot_spectrogram(self.writer, (pred_masks_linear * mix_mag).detach().cpu().view(-1, 1, 512, 256)[:8],
                                 self.state + '_PRED_MAG', iter_val)

        else:
            if K == 2:
                self.writer.add_scalars(self.state + ' losses_epoch', {'Voice Est Loss': self.l1}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Acc Est Loss': self.l2}, self.epoch)
            elif K == 4:
                self.writer.add_scalars(self.state + ' losses_epoch', {'Voice Est Loss': self.l1}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Drums Est Loss': self.l2}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Bass Est Loss': self.l3}, self.epoch)
                self.writer.add_scalars(self.state + ' losses_epoch', {'Other Est Loss': self.l4}, self.epoch)
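
The K == 2 / K == 4 branches in DWA.train implement the same Dynamic Weight Average rule, lambda_k = K * exp(w_k / T) / sum_i exp(w_i / T), where w_k is the ratio of task k's average loss over the last two epochs. A standalone sketch of that computation for arbitrary K (an illustration only, not part of the class):

import numpy as np

def dwa_weights(avg_cost, epoch, temperature):
    # avg_cost: array of shape (epochs, K) holding each task's average loss per epoch.
    # Returns K weights summing to K; uniform weights for the first two epochs.
    K = avg_cost.shape[1]
    if epoch < 2:
        return np.ones(K)
    w = avg_cost[epoch - 1] / avg_cost[epoch - 2]   # per-task loss ratios
    e = np.exp(w / temperature)
    return K * e / e.sum()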
Example #5

from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

from pytorch_transformers import XLNetModel, XLNetTokenizer, XLNetForSequenceClassification
from pytorch_transformers import AdamW

from tqdm import tqdm, trange
import pandas as pd
import io
import numpy as np
import torch
import matplotlib.pyplot as plt
import datetime

from utils.EarlyStopping import EarlyStopping

es = EarlyStopping()

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

print("device: {}".format(device))

df = pd.read_csv("./data/train.csv", encoding='utf-8')
sentences = df['comment_text'].values

test_df = pd.read_csv("./data/test.csv", encoding='utf-8')
test_sentences = test_df['comment_text']

# define output dataframe
sample = pd.read_csv("./data/sample_submission.csv")

for label in [
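
The snippet above is cut off mid-statement while iterating over the label columns. A hedged sketch of the tokenization step such a pipeline typically performs with these imports; MAX_LEN and the 'xlnet-base-cased' checkpoint are assumptions, not values taken from the original:

# Assumed continuation: turn comment_text into fixed-length token id sequences.
MAX_LEN = 128  # assumed maximum sequence length
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]
input_ids = [tokenizer.convert_tokens_to_ids(tokens) for tokens in tokenized_texts]
input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype='long',
                          truncating='post', padding='post')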
Example #6
    optimizer, scheduler = get_optimizer_and_lr_scheduler(
        model,
        args['optimizer'],
        args['learning_rate'],
        args['weight_decay'],
        steps_per_epoch=len(train_loader),
        epochs=args['epochs'])

    save_model_folder = f"../save_model/{args['dataset']}/{args['model_name']}"

    shutil.rmtree(save_model_folder, ignore_errors=True)
    os.makedirs(save_model_folder, exist_ok=True)

    early_stopping = EarlyStopping(patience=args['patience'],
                                   save_model_folder=save_model_folder,
                                   save_model_name=args['model_name'])

    loss_func = nn.CrossEntropyLoss()

    train_steps = 0

    for epoch in range(args['epochs']):
        model.train()

        train_y_trues = []
        train_y_predicts = []
        train_total_loss = 0.0
        train_loader_tqdm = tqdm(train_loader, ncols=120)
        for batch, (input_nodes, output_nodes,
                    blocks) in enumerate(train_loader_tqdm):
Example #7

    optimizer, scheduler = get_optimizer_and_lr_scheduler(
        model,
        args['optimizer'],
        args['learning_rate'],
        args['weight_decay'],
        steps_per_epoch=len(train_loader),
        epochs=args['epochs'])

    save_model_folder = f"../save_model/{args['dataset']}/{args['model_name']}"

    shutil.rmtree(save_model_folder, ignore_errors=True)
    os.makedirs(save_model_folder, exist_ok=True)

    early_stopping = EarlyStopping(patience=args['patience'],
                                   save_model_folder=save_model_folder,
                                   save_model_name=args['model_name'])

    loss_func = nn.BCELoss()

    train_steps = 0

    best_validate_RMSE, final_result = None, None

    for epoch in range(args['epochs']):
        model.train()

        train_y_trues = []
        train_y_predicts = []
        train_total_loss = 0.0
        train_loader_tqdm = tqdm(train_loader, ncols=120)
Example #8
class Baseline(pytorchfw):
    def __init__(self, model, rootdir, workname, main_device=0, trackgrad=False):
        super(Baseline, self).__init__(model, rootdir, workname, main_device, trackgrad)
        self.audio_dumps_path = os.path.join(DUMPS_FOLDER, 'audio')
        self.visual_dumps_path = os.path.join(DUMPS_FOLDER, 'visuals')
        self.main_device = main_device
        self.grid_unwarp = torch.from_numpy(
            warpgrid(BATCH_SIZE, NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
        self.EarlyStopChecker = EarlyStopping(patience=EARLY_STOPPING_PATIENCE)
        self.val_iterations = 0

    def print_args(self):
        setup_logger('log_info', self.workdir + '/info_file.txt',
                     FORMAT="[%(asctime)-15s %(filename)s:%(lineno)d %(funcName)s] %(message)s")
        logger = logging.getLogger('log_info')
        self.print_info(logger)
        logger.info(f'\r\t Spectrogram data dir: {ROOT_DIR}\r'
                    'TRAINING PARAMETERS: \r\t'
                    f'Run name: {self.workname}\r\t'
                    f'Batch size {BATCH_SIZE} \r\t'
                    f'Optimizer {OPTIMIZER} \r\t'
                    f'Initializer {INITIALIZER} \r\t'
                    f'Epochs {EPOCHS} \r\t'
                    f'LR General: {LR} \r\t'
                    f'SGD Momentum {MOMENTUM} \r\t'
                    f'Weight Decay {WEIGHT_DECAY} \r\t'
                    f'Pre-trained model:  {PRETRAINED} \r'
                    'MODEL PARAMETERS \r\t'
                    f'Nº instruments (K) {K} \r\t'
                    f'U-Net activation: {ACTIVATION} \r\t'
                    f'U-Net Input channels {INPUT_CHANNELS}\r\t'
                    f'U-Net Batch normalization {USE_BN} \r\t')

    def set_optim(self, *args, **kwargs):
        if OPTIMIZER == 'adam':
            return torch.optim.Adam(*args, **kwargs)
        elif OPTIMIZER == 'SGD':
            return torch.optim.SGD(*args, **kwargs)
        else:
            raise Exception('Unsupported optimizer; implement it.')

    def hyperparameters(self):
        self.dataparallel = False
        self.initializer = INITIALIZER
        self.EPOCHS = EPOCHS
        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        self.LR = LR
        self.scheduler = ReduceLROnPlateau(self.optimizer, patience=7, threshold=3e-4)

    def set_config(self):
        self.batch_size = BATCH_SIZE
        self.criterion = SingleSourceDirectLoss(self.main_device)

    @config
    @set_training
    def train(self):

        self.print_args()
        self.audio_dumps_folder = os.path.join(self.audio_dumps_path, self.workname, 'train')
        create_folder(self.audio_dumps_folder)
        self.visual_dumps_folder = os.path.join(self.visual_dumps_path, self.workname, 'train')
        create_folder(self.visual_dumps_folder)

        self.optimizer = self.set_optim(self.model.parameters(), momentum=MOMENTUM, lr=LR)
        training_data = UnetInput('train')
        self.train_loader = torch.utils.data.DataLoader(training_data,
                                                        batch_size=BATCH_SIZE,
                                                        shuffle=True,
                                                        num_workers=10)

        validation_data = UnetInput('val')
        self.val_loader = torch.utils.data.DataLoader(validation_data,
                                                      batch_size=BATCH_SIZE,
                                                      shuffle=True,
                                                      num_workers=10)
        for self.epoch in range(self.start_epoch, self.EPOCHS):
            with train(self):
                self.run_epoch(self.train_iter_logger)
            self.scheduler.step(self.loss)
            with val(self):
                self.run_epoch()
            stop = self.EarlyStopChecker.check_improvement(self.loss_.data.tuple['val'].epoch_array.val,
                                                           self.epoch)
            if stop:
                print('Early Stopping Epoch : [{0}], '
                      'Best Checkpoint Epoch : [{1}]'.format(self.epoch,
                                                             self.EarlyStopChecker.best_epoch))
                break

    def train_epoch(self, logger):
        j = 0
        self.train_iterations = len(iter(self.train_loader))
        with tqdm(self.train_loader, desc='Epoch: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(self):
            for inputs, visualization in pbar:
                try:
                    self.absolute_iter += 1
                    inputs = self._allocate_tensor(inputs)
                    output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                    self.loss = self.criterion(output)
                    self.optimizer.zero_grad()
                    self.loss.backward()
                    self.gradients()
                    self.optimizer.step()
                    pbar.set_postfix(loss=self.loss)
                    self.loss_.data.print_logger(self.epoch, j, self.train_iterations, logger)
                    self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                    j += 1
                except Exception as e:
                    try:
                        self.save_checkpoint(filename=os.path.join(self.workdir, 'checkpoint_backup.pth'))
                    except Exception:
                        self.err_logger.error('Failed to deal with exception. Could not save backup at {0} \n'
                                              .format(os.path.join(self.workdir, 'checkpoint_backup.pth')))
                    self.err_logger.error(str(e))
                    raise e
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
        self.save_checkpoint()

    def validate_epoch(self):
        with tqdm(self.val_loader, desc='Validation: [{0}/{1}]'.format(self.epoch, self.EPOCHS)) as pbar, ctx_iter(
                self):
            for inputs, visualization in pbar:
                self.val_iterations += 1
                self.loss_.data.update_timed()
                inputs = self._allocate_tensor(inputs)
                output = self.model(*inputs) if isinstance(inputs, list) else self.model(inputs)
                self.loss = self.criterion(output)
                self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)
                pbar.set_postfix(loss=self.loss)
        for tsi in self.tensor_scalar_items:
            setattr(self, tsi, getattr(self, tsi + '_').data.update_epoch(self.state))
        self.tensorboard_writer(self.loss, output, None, self.absolute_iter, visualization)

    def tensorboard_writer(self, loss, output, gt, absolute_iter, visualization):
        if self.state == 'train':
            iter_val = absolute_iter
        elif self.state == 'val':
            iter_val = self.val_iterations

        if iter_val % PARAMETER_SAVE_FREQUENCY == 0:
            text = visualization[1]
            self.writer.add_text('Filepath', text[-1], iter_val)
            phase = visualization[0].detach().cpu().clone().numpy()
            gt_mags_sq, pred_mags_sq, gt_mags, mix_mag, gt_masks, pred_masks = output
            if len(text) == BATCH_SIZE:
                grid_unwarp = self.grid_unwarp
            else:  # for the last batch, where the number of samples is generally smaller than the batch size
                grid_unwarp = torch.from_numpy(
                    warpgrid(len(text), NFFT // 2 + 1, STFT_WIDTH, warp=False)).to(self.main_device)
            pred_masks_linear = linearize_log_freq_scale(pred_masks, grid_unwarp)
            gt_masks_linear = linearize_log_freq_scale(gt_masks, grid_unwarp)
            oracle_spec = (mix_mag * gt_masks_linear)
            pred_spec = (mix_mag * pred_masks_linear)
            j = ISOLATED_SOURCE_ID
            source = SOURCES_SUBSET[j]

            for i, sample in enumerate(text):
                sample_id = os.path.basename(sample)[:-4]
                folder_name = os.path.basename(os.path.dirname(sample))
                pred_audio_out_folder = os.path.join(self.audio_dumps_folder, folder_name, sample_id)
                create_folder(pred_audio_out_folder)
                visuals_out_folder = os.path.join(self.visual_dumps_folder, folder_name, sample_id)
                create_folder(visuals_out_folder)

                gt_audio = torch.from_numpy(
                    istft_reconstruction(gt_mags.detach().cpu().numpy()[i][j], phase[i][0], HOP_LENGTH))
                pred_audio = torch.from_numpy(
                    istft_reconstruction(pred_spec.detach().cpu().numpy()[i][0], phase[i][0], HOP_LENGTH))
                librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'GT_' + source + '.wav'),
                                         gt_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)
                librosa.output.write_wav(os.path.join(pred_audio_out_folder, 'PR_' + source + '.wav'),
                                         pred_audio.cpu().detach().numpy(), TARGET_SAMPLING_RATE)

                ### SAVING MAG SPECTROGRAMS ###
                save_spectrogram(gt_mags[i][j].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_GT.png')
                save_spectrogram(oracle_spec[i][j].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_ORACLE.png')
                save_spectrogram(pred_spec[i][0].unsqueeze(0).detach().cpu(),
                                 os.path.join(visuals_out_folder, source), '_MAG_ESTIMATE.png')

            ### PLOTTING MAG SPECTROGRAMS ON TENSORBOARD ###
            plot_spectrogram(self.writer, gt_mags[:, j].detach().cpu().view(-1, 1, 512, 256)[:8],
                             self.state + '_GT_MAG', iter_val)
            plot_spectrogram(self.writer, (pred_masks_linear * mix_mag).detach().cpu().view(-1, 1, 512, 256)[:8],
                             self.state + '_PRED_MAG', iter_val)
Example #9
def train_model(model, batch_size, patience, n_epochs, gpu):
    
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = [] 
    
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    
    for epoch in range(1, n_epochs + 1):

        ###################
        # train the model #
        ###################
        model.train() # prep model for training
        t = tqdm(iter(train_loader), desc="[Train on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t): #start at index 0
            # get the inputs
            data = example["image"]
            #print("data size: {}".format(data.size()))
            target = example["fixations"]
            #print("target size: {}".format(target.size()))
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            
            #push data and targets to gpu
            if gpu:
                if torch.cuda.is_available():
                    data = data.to('cuda')
                    target = target.to('cuda')
                    
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            #print("output size: {}".format(output.size()))
            # calculate the loss
            loss = myLoss(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # record training loss
            train_losses.append(loss.item())
            #print("On iteration {} loss is {:.3f}".format(i+1, loss.item()))
            #for the first epoch, plot loss per iteration to have a quick overview of the early training phase
            iteration = i + 1
            #plot is always appending the newest value, so just give the last item of the list
            if epoch == 1:
                plotter_train.plot('loss', 'train', 'Loss per Iteration', iteration, train_losses[-1], batch_size, lr, 'iteration')
            

        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation
        t = tqdm(iter(val_loader), desc="[Valid on Epoch {}/{}]".format(epoch, n_epochs))
        for i, example in enumerate(t):
            # get the inputs
            data = example["image"]
            #print("input sum: {}".format(torch.sum(data)))
            target = example["fixations"]
            
            #push data and targets to gpu
            if gpu:
                if torch.cuda.is_available():
                    data = data.to('cuda')
                    target = target.to('cuda')
            
            #print("target sum: {}".format(torch.sum(target)))
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            #print("output sum: {}".format(torch.sum(output)))
            # calculate the loss
            loss = myLoss(output, target)
            # record validation loss
            valid_losses.append(loss.item())
            #plotter_val.plot('loss', 'val', 'Loss per Iteration', iteration, valid_losses[-1])

        # print training/validation statistics 
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        
        epoch_len = str(n_epochs)
        
        print_msg = ('[{}/{}] '.format(epoch, epoch_len) +
                     'train_loss: {:.5f} '.format(train_loss) +
                     'valid_loss: {:.5f}'.format(valid_loss))
        
        print(print_msg)
        
        #plot average loss for this epoch
        plotter_eval.plot('loss', 'train', 'Loss per Epoch', epoch, train_loss, batch_size, lr, 'epoch')
        plotter_eval.plot('loss', 'val', 'Loss per Epoch', epoch, valid_loss, batch_size, lr, 'epoch')
        
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        
        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, model, batch_size, lr)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break
        
    # load the last checkpoint with the best model
    name = "checkpoint_batch_size_{}_lr_{}.pt".format(batch_size, lr)
    model.load_state_dict(torch.load(name))

    return model, avg_train_losses, avg_valid_losses
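
Both train_model variants run the validation pass with gradients enabled. A common refinement, shown here only as a sketch and not present in the original functions, is to wrap that loop in torch.no_grad() so no autograd graph is built for validation batches:

# Sketch only: gradient-free validation pass (criterion stands in for the criterion/myLoss used above).
model.eval()
with torch.no_grad():
    for example in val_loader:
        data, target = example["image"], example["fixations"]
        if gpu and torch.cuda.is_available():
            data, target = data.to('cuda'), target.to('cuda')
        output = model(data)
        valid_losses.append(criterion(output, target).item())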