Example #1
def fit(model: Module,
        optimiser: Optimizer,
        loss_fn: Callable,
        epochs: int,
        dataloader: DataLoader,
        prepare_batch: Callable,
        metrics: List[Union[str, Callable]] = None,
        callbacks: List[Callback] = None,
        verbose: bool = True,
        fit_function: Callable = gradient_step,
        fit_function_kwargs: dict = {}):
    """Function to abstract away training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows common
    training functionality to be easily re-used, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on labelled
            batches. For more complex training procedures (meta-learning etc...) you will need to write your own
            fit_function
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    callbacks = CallbackList([
        DefaultCallback(),
    ] + (callbacks or []) + [
        ProgressBarLogger(),
    ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            loss, y_pred = fit_function(model, optimiser, loss_fn, x, y,
                                        **fit_function_kwargs)

            batch_logs['loss'] = loss.item()

            # Loops through all metrics
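            # Note: metrics are computed against only the final timestep of y (presumably y has a time dimension here)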
            batch_logs = batch_metrics(model, y_pred, y[:, -1:, :], metrics,
                                       batch_logs)

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
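
A minimal call site for this version might look like the sketch below; the network, dataloader and device are hypothetical placeholders, while the loss function and optimiser are standard PyTorch objects.

import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical placeholders -- the real training scripts build these from their own config.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

def prepare_batch(batch):
    # Move the raw (x, y) pair onto the training device.
    x, y = batch
    return x.to(device), y.to(device)

fit(model,
    optimiser,
    loss_fn=F.cross_entropy,
    epochs=10,
    dataloader=train_loader,  # assumed to be an existing torch DataLoader
    prepare_batch=prepare_batch,
    metrics=['categorical_accuracy'])
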
Example #2
def fit(model: Module,
        optimiser: Optimizer,
        loss_fn: Callable,
        epochs: int,
        dataloader: DataLoader,
        writer: SummaryWriter,
        prepare_batch: Callable,
        metrics: List[Union[str, Callable]] = None,
        callbacks: List[Callback] = None,
        verbose: bool = True,
        fit_function: Callable = gradient_step,
        stnmodel=None,
        stnoptim=None,
        args=None,
        fit_function_kwargs: dict = {}):
    """Function to abstract away training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows common
    training functionality to be easily re-used, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        writer: `tensorboard.SummaryWriter` instance to write plots to tensorboard
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on labelled
            batches. For more complex training procedures (meta-learning etc...) you will need to write your own
            fit_function
        stnmodel: Optional augmentation model (an STN, judging by the name), forwarded to `fit_function` via
            `fit_function_kwargs`
        stnoptim: Optimiser for `stnmodel`, forwarded to `fit_function` in the same way
        args: Parsed argument namespace forwarded to `fit_function`
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    callbacks = CallbackList([
        DefaultCallback(),
    ] + (callbacks or []) + [
        ProgressBarLogger(),
    ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            # Copy before injecting per-call values so the shared default dict is not mutated in place
            fit_function_kwargs = dict(fit_function_kwargs, stnmodel=stnmodel,
                                       stnoptim=stnoptim, args=args)

            n_shot = fit_function_kwargs['n_shot']
            k_way = fit_function_kwargs['k_way']

            loss, y_pred, aug_imgs = fit_function(model, optimiser, loss_fn, x,
                                                  y, **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            # Useful for viewing images for debugging (Doesn't work for maml)
            #TODO (kamal): customize for maml
            # if batch_index % 100 == 99:
            # writer.add_figure('episode', plot_classes_preds(aug_imgs, n_shot, k_way),
            # global_step=len(dataloader)*(epoch-1) + batch_index)

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)

            # Log training loss and categorical accuracy
            writer.add_scalar('Train_loss', batch_logs['loss'],
                              len(dataloader) * (epoch - 1) + batch_index)
            writer.add_scalar('categorical_accuracy',
                              batch_logs['categorical_accuracy'],
                              len(dataloader) * (epoch - 1) + batch_index)

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
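
The fit_function used here returns an extra aug_imgs value and receives stnmodel, stnoptim, args, n_shot and k_way through fit_function_kwargs. The actual meta-learning step lives elsewhere in the repository; a hypothetical signature that would be compatible with this call is sketched below.

def stn_gradient_step(model, optimiser, loss_fn, x, y,
                      n_shot, k_way, stnmodel=None, stnoptim=None, args=None,
                      **kwargs):
    # Hypothetical sketch: optionally run the batch through the STN augmenter,
    # then take a plain supervised gradient step and return the (augmented) images too.
    if stnmodel is not None:
        x = stnmodel(x)
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()
    return loss, y_pred, x
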
Example #3
def fit(model: Union[Module, List[Module]],
        optimiser: Optimizer,
        loss_fn: Callable,
        epochs: int,
        dataloader: DataLoader,
        prepare_batch: Callable,
        metrics: List[Union[str, Callable]] = None,
        callbacks: List[Callback] = None,
        verbose: bool = True,
        fit_function: Callable = gradient_step,
        n_models: int = 1,
        fit_function_kwargs: dict = {}):
    """Function to abstract away training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows common
    training functionality to be easily re-used, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model (or list of models) to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on labelled
            batches. For more complex training procedures (meta-learning etc...) you will need to write your own
            fit_function
        n_models: Number of models in the ensemble; passed through to the callbacks via their params
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    fit_function_kwargs_logs = dict(fit_function_kwargs)

    fit_function_kwargs_logs['train'] = False
    fit_function_kwargs_logs['pred_fn'] = logmeanexp_preds

    callbacks = CallbackList([
        DefaultCallback(),
    ] + (callbacks or []) + [
        ProgressBarLogger(),
    ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser,
        'n_models': n_models
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            # result = {
            #     "meta_batch_loss": meta_batch_loss,
            #     "task_predictions": task_predictions,
            #     "models_losses": models_losses,
            #     "models_predictions": models_predictions,
            #     "mean_support_loss": mean_support_loss
            # }

            result = fit_function(model, optimiser, loss_fn, x, y,
                                  **fit_function_kwargs)

            loss = result['meta_batch_loss']
            y_pred = result['task_predictions']
            models_losses = result['models_losses']
            models_preds = result['models_predictions']
            support_loss = result['mean_support_loss']

            batch_logs['loss'] = loss.item()
            batch_logs['support_loss'] = support_loss

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)

            task_preds = defaultdict(list)
            for model_pred in models_preds:
                for i, task in enumerate(model_pred):
                    task_preds[i].append(task)

            # task_preds : {task_idx : [model_1_pred, model_2_pred, ....] }
            logprobs_pred = []
            logprobs_loss = []
            for task_idx, task_pred in task_preds.items():
                y_pred_ = logmeanexp_preds(task_pred)
                logprobs_pred.append(y_pred_)

            y_pred_logprobs = torch.cat(logprobs_pred)
            # TODO: make it work with MixturePredLoss
            # with torch.no_grad():
            #     loss_logprobs = loss_fn(y_pred_logprobs, y).item()
            #
            # batch_logs['logprobs_loss'] = loss_logprobs
            batch_logs['logprobs_nll'] = nll_loss(y_pred_logprobs,
                                                  y,
                                                  reduction="mean").item()
            batch_logs = batch_metrics(model, y_pred_logprobs, y, metrics,
                                       batch_logs, 'logprobs')

            for i, (model_loss, model_pred) in enumerate(zip(models_losses,
                                                             models_preds)):
                batch_logs[f'loss_{i}'] = nmean(model_loss)
                batch_logs[f'categorical_accuracy_{i}'] = NAMED_METRICS[
                    'categorical_accuracy'](y, torch.cat(model_pred))

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
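
logmeanexp_preds is imported from elsewhere in the repository. A minimal sketch of log-mean-exp ensembling, assuming each element of task_pred is a [batch, classes] tensor of per-class log-probabilities, could look like this:

import math
import torch

def logmeanexp_preds_sketch(preds):
    # preds: list of [batch, classes] log-probability tensors, one per ensemble member.
    stacked = torch.stack(preds)  # [n_models, batch, classes]
    # log(mean(exp(x))) over the model dimension, computed stably via logsumexp.
    return torch.logsumexp(stacked, dim=0) - math.log(len(preds))
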
Example #4
def fit(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader,
        prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None,
        verbose: bool = True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = {}):
    """Function to abstract away training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows common
    training functionality to be easily re-used, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on labelled
            batches. For more complex training procedures (meta-learning etc...) you will need to write your own
            fit_function
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    # The default callback averages the batch accuracy and loss
    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    # The model and all other information have been passed to the callbacks; nothing else needs to be done here
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    # creates a csv logger file
    callbacks.on_train_begin()

    for epoch in range(1, epochs+1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            # for each new batch create a batch_log
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            # this does nothing for protonets except the progress bar
            callbacks.on_batch_begin(batch_index, batch_logs)
            # y here is of shape (queries * k-way,); labels lie in [0, k)

            x, y = prepare_batch(batch)

            # What we expect here is a loss for the above batch and the probabilities of the predicted classes
            # accuracy is determined on queries only
            loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            # Loops through all metrics
            # For each episode, accuracy = number of correct query predictions / total number of queries
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)
            # This does nothing for protonets except update the progress bar
            # The categorical accuracy and loss shown during training refer to the query samples only
            callbacks.on_batch_end(batch_index, batch_logs)
        # evalfewshot is run only here, after one complete epoch
        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
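
gradient_step, the default fit_function, is defined elsewhere in the repository. A minimal sketch of what such a supervised step typically does, assuming it only needs the five positional arguments used above, is:

def gradient_step_sketch(model, optimiser, loss_fn, x, y, **kwargs):
    # One supervised optimisation step: forward pass, loss, backward pass, parameter update.
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()
    return loss, y_pred
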
Example #5
def fit_gan_few_shot(encoder: Module,
                     generator: Module,
                     classifier: Module,
                     discriminator: Module,
                     dataloader: DataLoader,
                     params_str: str,
                     k: int,
                     n: int,
                     epochs: int,
                     prepare_batch: Callable,
                     latent_sizeC: int,
                     latent_sizeB: int,
                     device,
                     e_optimizer: Optimizer,
                     g_optimizer: Optimizer,
                     c_optimizer: Optimizer,
                     d_optimizer: Optimizer,
                     metrics: List[Union[str, Callable]] = None,
                     callbacks: List[Callback] = None,
                     verbose: bool = True,
                     is_complete=True):
    """Function to abstract away training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows common
    training functionality to be easily re-used, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        encoder, generator, classifier, discriminator: Modules trained jointly by the few-shot GAN procedure
        dataloader: `torch.DataLoader` instance to fit the models to
        params_str: String used to build the checkpoint filenames
        k: k-way setting used in the name of the monitored validation metric
        n: n-shot setting used in the name of the monitored validation metric
        epochs: Number of epochs of fitting to be performed
        prepare_batch: Callable to perform any desired preprocessing
        latent_sizeC, latent_sizeB: Latent dimensions passed to `gradient_step_gan_few_shot`
        device: Device on which the combined encoder/classifier model is placed
        e_optimizer, g_optimizer, c_optimizer, d_optimizer: Optimisers for the encoder, generator, classifier and
            discriminator respectively
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        is_complete: If `True`, the `dae_loss` returned by the gradient step is also logged
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    callbacks = CallbackList([
        DefaultCallback(),
    ] + (callbacks or []) + [
        ProgressBarLogger(),
    ])

    class EncoderClassifier(nn.Module):
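        # Combines the closed-over encoder and classifier so the callbacks see a single nn.Module.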
        def forward(self, x):
            c, v = encoder(x)
            return classifier(c)

    encoderclassifier = EncoderClassifier().to(device, dtype=torch.double)

    callbacks.set_model(encoderclassifier)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': criterionClass
    })

    checkpoint_e = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan',
                              str(params_str) + '_encoder.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')

    checkpoint_g = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan',
                              str(params_str) + '_generator.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')

    checkpoint_c = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan',
                              str(params_str) + '_classifier.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')

    checkpoint_d = ModelCheckpoint(
        filepath=os.path.join(PATH, 'models', 'semantic_gan',
                              str(params_str) + '_discriminator.pth'),
        monitor='val_' + str(n) + '-shot_' + str(k) + '-way_acc')

    checkpoint_e.set_model(encoder)
    checkpoint_g.set_model(generator)
    checkpoint_c.set_model(classifier)
    checkpoint_d.set_model(discriminator)

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)
            cl_loss, cl_score, dae_loss, ae_loss, cae_loss = gradient_step_gan_few_shot(
                e_optimizer,
                g_optimizer,
                c_optimizer,
                d_optimizer,
                x,
                y,
                device,
                encoder,
                generator,
                classifier,
                discriminator,
                latent_sizeB,
                latent_sizeC,
                epoch,
                is_complete=is_complete)
            batch_logs['cl_loss'] = cl_loss.item()
            #batch_logs['d_loss'] = d_loss.item()
            batch_logs['ae_loss'] = ae_loss.item()
            batch_logs['cae_loss'] = cae_loss.item()

            batch_logs['cl_score'] = cl_score.item()
            #batch_logs['real_score'] = real_score.item()
            #batch_logs['fake_score'] = fake_score.item()
            if is_complete:
                batch_logs['dae_loss'] = dae_loss.item()

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)
        checkpoint_e.on_epoch_end(epoch, epoch_logs)
        checkpoint_c.on_epoch_end(epoch, epoch_logs)
        checkpoint_g.on_epoch_end(epoch, epoch_logs)
        checkpoint_d.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
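
Each ModelCheckpoint above monitors an epoch-log key of the form 'val_{n}-shot_{k}-way_acc', which has to be written by one of the callbacks passed into this function (for example a few-shot evaluation callback). A hypothetical sketch of such a callback:

class ValAccLogger(Callback):
    # Hypothetical callback: evaluates the wrapped encoder/classifier on a
    # held-out few-shot task and writes the key the checkpoints monitor.
    def __init__(self, eval_fn, n, k):
        super().__init__()
        self.eval_fn = eval_fn
        self.key = 'val_' + str(n) + '-shot_' + str(k) + '-way_acc'

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        logs[self.key] = self.eval_fn()
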
Example #6
def fit(model: Module,
        optimiser: Optimizer,
        loss_fn: Callable,
        epochs: int,
        dataloader: DataLoader,
        prepare_batch: Callable,
        metrics: List[Union[str, Callable]] = None,
        callbacks: List[Callback] = None,
        verbose: bool = True,
        fit_function: Callable = gradient_step,
        fit_function_kwargs: dict = {}):
    """Function to abstract away training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows common
    training functionality to be easily re-used, provided it is written as a subclass of voicemap.Callback
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc... See voicemap.callbacks for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on labelled
            batches. For more complex training procedures (meta-learning etc...) you will need to write your own
            fit_function
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)

    if verbose:
        print('num_batches: ', num_batches)

    batch_size = dataloader.batch_size

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    callbacks = CallbackList([
        DefaultCallback(),
    ] + (callbacks or []) + [
        ProgressBarLogger(),
    ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': (metrics or []),
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}

        torch.cuda.empty_cache()

        for batch_index, batch in enumerate(dataloader):

            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            input_ids, attention_mask, label = prepare_batch(batch)
            input_ids = torch.squeeze(input_ids, dim=1)
            attention_mask = torch.squeeze(attention_mask, dim=1)

            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            label = label.to(device)

            try:
                print('Before Loss/Pred')
                gpu_dict = get_gpu_info()
                print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.
                      format(gpu_dict['mem_total'], gpu_dict['mem_used'],
                             gpu_dict['mem_used_percent']))
            except Exception:
                pass

            loss, y_pred = fit_function(model, optimiser, loss_fn, input_ids,
                                        attention_mask, label,
                                        **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            try:
                print('After Loss/Pred')
                gpu_dict = get_gpu_info()
                print('Total GPU Mem: {} , Used GPU Mem: {}, Used Percent: {}'.
                      format(gpu_dict['mem_total'], gpu_dict['mem_used'],
                             gpu_dict['mem_used_percent']))
            except Exception:
                pass

            # Loops through all metrics
            batch_logs = batch_metrics(model, y_pred, label, metrics,
                                       batch_logs)

            callbacks.on_batch_end(batch_index, batch_logs)

            torch.cuda.empty_cache()

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
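
This variant calls fit_function with six positional arguments (input_ids, attention_mask, label) instead of the usual (x, y) pair, so the default gradient_step will not match. A hypothetical compatible step, assuming model is a transformer-style sequence classifier that returns logits directly:

def transformer_gradient_step(model, optimiser, loss_fn, input_ids,
                              attention_mask, label, **kwargs):
    # Hypothetical sketch: one supervised step over a tokenised text batch.
    model.train()
    optimiser.zero_grad()
    logits = model(input_ids=input_ids, attention_mask=attention_mask)
    loss = loss_fn(logits, label)
    loss.backward()
    optimiser.step()
    return loss, logits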