Example #1
    def fit(
        self, train_loader, optimizer, criterion, device='cpu',
        epochs=1, l1_factor=0.0, val_loader=None, callbacks=None
    ):
        """Train the model.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (str or torch.device): Device where the data
                will be loaded.
            epochs (int, optional): Number of epochs to train the model. (default: 1)
            l1_factor (float, optional): L1 regularization factor. (default: 0)
            val_loader (torch.utils.data.DataLoader, optional): Validation data
                loader. (default: None)
            callbacks (list, optional): List of callbacks to be used during training.
                (default: None)
        """
        self.learner = Learner(
            self, optimizer, criterion, train_loader, device=device, epochs=epochs,
            val_loader=val_loader, l1_factor=l1_factor, callbacks=callbacks
        )
        self.learner.fit()
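A minimal usage sketch for this fit() wrapper; the model class Net, the data loaders, and the hyperparameters below are hypothetical stand-ins, not part of the snippet above:

# Hypothetical usage: train a TensorNet model for 5 epochs on the GPU.
import torch.nn as nn
import torch.optim as optim

model = Net()  # placeholder: any model exposing this fit() method
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model.fit(train_loader, optimizer, criterion,
          device='cuda', epochs=5, val_loader=val_loader)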
Example #2
class BaseModel(nn.Module):

    def __init__(self):
        """This function instantiates all the model layers."""
        super(BaseModel, self).__init__()
        self.learner = None
    
    def forward(self, x):
        """This function defines the forward pass of the model.

        Args:
            x: Input.
        
        Returns:
            Model output.
        """
        raise NotImplementedError

    def summary(self, input_size):
        """Generates model summary.

        Args:
            input_size (tuple): Size of input to the model.
        """
        torchsummary.summary(self, input_size)

    def fit(
        self, train_loader, optimizer, criterion, device='cpu',
        epochs=1, l1_factor=0.0, val_loader=None, callbacks=None
    ):
        """Train the model.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (str or torch.device): Device where the data
                will be loaded.
            epochs (int, optional): Number of epochs to train the model. (default: 1)
            l1_factor (float, optional): L1 regularization factor. (default: 0)
            val_loader (torch.utils.data.DataLoader, optional): Validation data
                loader. (default: None)
            callbacks (list, optional): List of callbacks to be used during training.
                (default: None)
        """
        self.learner = Learner(
            self, optimizer, criterion, train_loader, device=device, epochs=epochs,
            val_loader=val_loader, l1_factor=l1_factor, callbacks=callbacks
        )
        self.learner.fit()
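Because forward() raises NotImplementedError, BaseModel is meant to be subclassed. A minimal sketch of a concrete subclass (the layer sizes are illustrative):

import torch.nn as nn
import torch.nn.functional as F

class MnistNet(BaseModel):
    """Hypothetical concrete model built on top of BaseModel."""

    def __init__(self):
        super(MnistNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Flatten the input, apply one hidden layer, return raw logits.
        x = F.relu(self.fc1(x.view(x.size(0), -1)))
        return self.fc2(x)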
Example #3
    def create_learner(self,
                       train_loader,
                       optimizer,
                       criterion,
                       device='cpu',
                       epochs=1,
                       l1_factor=0.0,
                       val_loader=None,
                       callbacks=None,
                       metrics=None,
                       activate_loss_logits=False,
                       record_train=True):
        """Create Learner object.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (str or torch.device): Device where the data
                will be loaded.
            epochs (int, optional): Number of epochs to train the model. (default: 1)
            l1_factor (float, optional): L1 regularization factor. (default: 0)
            val_loader (torch.utils.data.DataLoader, optional): Validation data
                loader. (default: None)
            callbacks (list, optional): List of callbacks to be used during training.
                (default: None)
            metrics (list of str, optional): List of names of the metrics for model
                evaluation. (default: None)
        """
        self.learner = Learner(train_loader,
                               optimizer,
                               criterion,
                               device=device,
                               epochs=epochs,
                               val_loader=val_loader,
                               l1_factor=l1_factor,
                               callbacks=callbacks,
                               metrics=metrics,
                               activate_loss_logits=activate_loss_logits,
                               record_train=record_train)
        self.learner.set_model(self)
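A usage sketch for create_learner(), assuming the Learner object exposes a fit() method as in the other examples (the loaders and settings below are placeholders):

# Hypothetical: build the learner once with accuracy tracking, then train.
model.create_learner(train_loader, optimizer, criterion,
                     device='cuda', epochs=10,
                     val_loader=val_loader, metrics=['accuracy'])
model.learner.fit()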
Example #4
class BaseModel(nn.Module):
    def __init__(self):
        """This function instantiates all the model layers."""
        super(BaseModel, self).__init__()
        self.learner = None

    def forward(self, x):
        """This function defines the forward pass of the model.

        Args:
            x: Input.
        
        Returns:
            Model output.
        """
        raise NotImplementedError

    def summary(self, input_size):
        """Generates model summary.

        Args:
            input_size (tuple): Size of input to the model.
        """
        model_summary(self, input_size)

    def create_learner(self,
                       train_loader,
                       optimizer,
                       criterion,
                       device='cpu',
                       epochs=1,
                       l1_factor=0.0,
                       val_loader=None,
                       callbacks=None,
                       metrics=None,
                       activate_loss_logits=False,
                       record_train=True):
        """Create Learner object.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (str or torch.device): Device where the data
                will be loaded.
            epochs (int, optional): Number of epochs to train the model. (default: 1)
            l1_factor (float, optional): L1 regularization factor. (default: 0)
            val_loader (torch.utils.data.DataLoader, optional): Validation data
                loader. (default: None)
            callbacks (list, optional): List of callbacks to be used during training.
                (default: None)
            metrics (list of str, optional): List of names of the metrics for model
                evaluation. (default: None)
        """
        self.learner = Learner(train_loader,
                               optimizer,
                               criterion,
                               device=device,
                               epochs=epochs,
                               val_loader=val_loader,
                               l1_factor=l1_factor,
                               callbacks=callbacks,
                               metrics=metrics,
                               activate_loss_logits=activate_loss_logits,
                               record_train=record_train)
        self.learner.set_model(self)

    def set_learner(self, learner):
        """Assign a learner object to the model.

        Args:
            learner (Learner): Learner object.
        """
        self.learner = learner
        self.learner.set_model(self)

    def fit(self, *args, start_epoch=1, **kwargs):
        """Train the model.

        Args:
            start_epoch (int, optional): Start epoch for training.
                (default: 1)
        """

        # Check learner
        if self.learner is None:
            print('Creating a learner object.')
            self.create_learner(*args, **kwargs)

        # Train Model
        self.learner.fit(start_epoch=start_epoch)

    def save_learnable(self, filepath, **kwargs):
        """Save the learnable model.

        Args:
            filepath (str): File in which the model will be saved.
            **kwargs (optional): Additional parameters to save with the model.
        """
        if self.learner is None:
            raise ValueError('Cannot save an untrained model.')

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.learner.optimizer.state_dict(),
                **kwargs
            }, filepath)

    def save(self, filepath):
        """Save the model.

        Args:
            filepath (str): File in which the model will be saved.
        """
        torch.save(self, filepath)

    def load(self, filepath):
        """Load the model.

        Args:
            filepath (str): File in which the model is saved.
        
        Returns:
            Parameters saved inside the checkpoint file.
        """
        checkpoint = torch.load(filepath)
        self.load_state_dict(checkpoint['model_state_dict'])
        return {k: v for k, v in checkpoint.items() if k != 'model_state_dict'}
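A checkpoint round-trip sketch using save_learnable() and load(); the file name and the extra keyword argument are illustrative:

# Hypothetical round-trip: extra kwargs are stored alongside the weights,
# and load() returns everything except 'model_state_dict'.
model.fit(train_loader, optimizer, criterion, epochs=2)
model.save_learnable('checkpoint.pt', epoch=2)
extras = model.load('checkpoint.pt')
print(sorted(extras))  # ['epoch', 'optimizer_state_dict']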
Example #5
class LRFinder:
    """Learning rate range test.
    The learning rate range test increases the learning rate in a pre-training run
    between two boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of learning rates
    and what is the optimal learning rate.

    Args:
        model (torch.nn.Module): Model Instance.
        optimizer (torch.optim): Optimizer where the defined learning rate
            is assumed to be the lower boundary of the range test.
        criterion (torch.nn): Loss function.
        metric (str, optional): Metric to use for finding the best learning rate. Can
            be either 'loss' or 'accuracy'. (default: 'loss')
        device (str or torch.device, optional): Device where the computation
            will take place. If None, uses the same device as `model`. (default: None)
        memory_cache (bool, optional): If this flag is set to True, state_dict of
            model and optimizer will be cached in memory. Otherwise, they will be saved
            to files under the `cache_dir`. (default: True)
        cache_dir (str, optional): Path for storing temporary files. If no path is
            specified, system-wide temporary directory is used. Notice that this
            parameter will be ignored if `memory_cache` is True. (default: None)
    """
    def __init__(
        self,
        model,
        optimizer,
        criterion,
        metric='loss',
        device=None,
        memory_cache=True,
        cache_dir=None,
    ):
        # Parameter validation

        # Check if correct 'metric' has been given
        if metric not in ['loss', 'accuracy']:
            raise ValueError(
                f'For "metric" expected one of (loss, accuracy), got {metric}')

        # Check if the optimizer is already attached to a scheduler
        self.optimizer = optimizer
        self._check_for_scheduler()

        self.model = model
        self.criterion = criterion
        self.metric = metric
        self.history = {'lr': [], 'metric': []}
        self.best_metric = None
        self.best_lr = None
        self.memory_cache = memory_cache
        self.cache_dir = cache_dir
        self.learner = None

        # Save the original state of the model and optimizer so they can be restored if
        # needed
        self.model_device = next(self.model.parameters()).device
        self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)
        self.state_cacher.store('model', self.model.state_dict())
        self.state_cacher.store('optimizer', self.optimizer.state_dict())

        # If device is None, use the same as the model
        self.device = device if device else self.model_device

    def reset(self):
        """Restores the model and optimizer to their initial states."""
        self.model.load_state_dict(self.state_cacher.retrieve('model'))
        self.optimizer.load_state_dict(self.state_cacher.retrieve('optimizer'))
        self.model.to(self.model_device)

        if self.learner is not None:
            self.learner.reset_history()

    def _check_for_scheduler(self):
        """Check if the optimizer has and existing scheduler attached to it."""
        for param_group in self.optimizer.param_groups:
            if 'initial_lr' in param_group:
                raise RuntimeError(
                    'Optimizer already has a scheduler attached to it')

    def _set_learning_rate(self, new_lrs):
        """Set the given learning rates in the optimizer."""
        if not isinstance(new_lrs, list):
            new_lrs = [new_lrs] * len(self.optimizer.param_groups)
        if len(new_lrs) != len(self.optimizer.param_groups):
            raise ValueError(
                'Length of new_lrs is not equal to the number of parameter groups in the given optimizer'
            )

        # Set the learning rates to the parameter groups
        for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs):
            param_group['lr'] = new_lr

    def range_test(
        self,
        train_loader,
        iterations,
        mode='iteration',
        learner=None,
        val_loader=None,
        start_lr=None,
        end_lr=10,
        step_mode='exp',
        smooth_f=0.0,
        diverge_th=5,
    ):
        """Performs the learning rate range test.

        Args:
            train_loader (torch.utils.data.DataLoader): The training set data loader.
            iterations (int): The number of iterations/epochs over which the test occurs.
                If `mode` is set to 'iteration', this is the number of iterations;
                if `mode` is set to 'epoch', it is the number of epochs.
            mode (str, optional): Whether to update the learning rate after each
                'iteration' or after each 'epoch'. (default: 'iteration')
            learner (Learner, optional): Learner object for the model. (default: None) 
            val_loader (torch.utils.data.DataLoader, optional): If None, the range test
                will only use the training metric. When given a data loader, the model is
                evaluated after each iteration on that dataset and the evaluation metric
                is used. Note that in this mode the test takes significantly longer but
                generally produces more precise results. (default: None)
            start_lr (float, optional): The starting learning rate for the range test.
                If None, uses the learning rate from the optimizer. (default: None)
            end_lr (float, optional): The maximum learning rate to test. (default: 10)
            step_mode (str, optional): One of the available learning rate policies,
                linear or exponential ('linear', 'exp'). (default: 'exp')
            smooth_f (float, optional): The metric smoothing factor within the [0, 1)
                interval. Disabled if set to 0, otherwise the metric is smoothed using
                exponential smoothing. (default: 0.0)
            diverge_th (int, optional): The test is stopped when the metric surpasses the
                threshold: diverge_th * best_metric. To disable, set it to 0. (default: 5)
        """

        # Check if a valid 'mode' has been given
        if mode not in ['iteration', 'epoch']:
            raise ValueError(
                f'For "mode" expected one of (iteration, epoch), got {mode}')

        # Reset test results
        self.history = {'lr': [], 'metric': []}
        self.best_metric = None
        self.best_lr = None

        # Check if the optimizer is already attached to a scheduler
        self._check_for_scheduler()

        # Set the starting learning rate
        if start_lr:
            self._set_learning_rate(start_lr)

        # Initialize the proper learning rate policy
        if step_mode.lower() == 'exp':
            lr_schedule = ExponentialLR(self.optimizer, end_lr, iterations)
        elif step_mode.lower() == 'linear':
            lr_schedule = LinearLR(self.optimizer, end_lr, iterations)
        else:
            raise ValueError(f'Expected one of (exp, linear), got {step_mode}')

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError('smooth_f is outside the range [0, 1)')

        # Set accuracy metric if needed
        metrics = None
        if self.metric == 'accuracy':
            metrics = ['accuracy']

        # Get the learner object
        if learner is not None:
            self.learner = learner(train_loader,
                                   self.optimizer,
                                   self.criterion,
                                   device=self.device,
                                   val_loader=val_loader,
                                   metrics=metrics)
        else:
            self.learner = Learner(train_loader,
                                   self.optimizer,
                                   self.criterion,
                                   device=self.device,
                                   val_loader=val_loader,
                                   metrics=metrics)
        self.learner.set_model(self.model)

        train_iterator = InfiniteDataLoader(train_loader)
        pbar = ProgressBar(target=iterations, width=8)
        if mode == 'iteration':
            print(mode.title() + 's')
        for iteration in range(iterations):
            # Train model
            if mode == 'epoch':
                print(f'{mode.title()} {iteration + 1}:')
            self._train_model(mode, train_iterator)
            if val_loader:
                self.learner.validate(verbose=False)

            # Get metric value
            metric_value = self._get_metric(val_loader)

            # Update the learning rate
            lr_schedule.step()
            self.history['lr'].append(lr_schedule.get_lr()[0])

            # Track the best metric and smooth it if smooth_f is specified
            if iteration == 0:
                self.best_metric = metric_value
                self.best_lr = self.history['lr'][-1]
            else:
                if smooth_f > 0:
                    metric_value = smooth_f * metric_value + (
                        1 - smooth_f) * self.history['metric'][-1]
                if ((self.metric == 'loss' and metric_value < self.best_metric)
                        or (self.metric == 'accuracy'
                            and metric_value > self.best_metric)):
                    self.best_metric = metric_value
                    self.best_lr = self.history['lr'][-1]

            # Check if the metric has diverged; if it has, stop the test.
            # Compare raw values before converting to the display scale so that
            # accuracy (stored in [0, 1]) stays on the same scale as best_metric.
            self.history['metric'].append(metric_value)
            diverged = (diverge_th > 0
                        and ((self.metric == 'loss'
                              and metric_value > self.best_metric * diverge_th) or
                             (self.metric == 'accuracy'
                              and metric_value < self.best_metric / diverge_th)))
            metric_value = self._display_metric_value(metric_value)
            if diverged:
                if mode == 'iteration':
                    pbar.update(iterations - 1,
                                values=[('lr', self.history['lr'][-1]),
                                        (self.metric.title(), metric_value)])
                print('\nStopping early, the metric has diverged.')
                break
            else:
                if mode == 'epoch':
                    lr = self.history['lr'][-1]
                    print(
                        f'Learning Rate: {lr:.4f}, {self.metric.title()}: {metric_value:.2f}\n'
                    )
                elif mode == 'iteration':
                    pbar.update(iteration,
                                values=[('lr', self.history['lr'][-1]),
                                        (self.metric.title(), metric_value)])

        metric = self._display_metric_value(self.best_metric)
        if mode == 'epoch':
            print(
                f'Learning Rate: {self.best_lr:.4f}, {self.metric.title()}: {metric:.2f}\n'
            )
        elif mode == 'iteration':
            pbar.add(1,
                     values=[('lr', self.best_lr),
                             (self.metric.title(), metric)])
        print('Learning rate search finished.')

    def _train_model(self, mode, train_iterator):
        if mode == 'iteration':
            self.learner.model.train()
            data, targets = train_iterator.get_batch()
            loss = self.learner.train_batch((data, targets))
            self.learner.update_training_history(loss)
        elif mode == 'epoch':
            self.learner.train_epoch()

    def _get_metric(self, validation=None):
        if self.metric == 'loss':
            if validation:
                return self.learner.val_losses[-1]
            return self.learner.train_losses[-1]
        elif self.metric == 'accuracy':
            if validation:
                return self.learner.val_metrics[0][-1] / 100
            return self.learner.train_metrics[0][-1] / 100

    def _display_metric_value(self, value):
        if self.metric == 'accuracy':
            return value * 100
        return value

    def plot(self, log_lr=True, show_lr=None):
        """Plots the learning rate range test.

        Args:
            log_lr (bool, optional): True to plot the learning rate on a logarithmic
                scale; otherwise, a linear scale is used. (default: True)
            show_lr (float, optional): If set, adds a vertical line to visualize the
                specified learning rate. (default: None)
        """

        if show_lr is not None and not isinstance(show_lr, float):
            raise ValueError("show_lr must be float")

        # Get the data to plot from the history dictionary.
        lrs = self.history['lr']
        metrics = self.history['metric']

        # Plot metric_value as a function of the learning rate
        plt.plot(lrs, metrics)
        if log_lr:
            plt.xscale('log')
        plt.xlabel('Learning rate')
        plt.ylabel(self.metric.title())

        if show_lr is not None:
            plt.axvline(x=show_lr, color='red')
        plt.show()
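A usage sketch for the range test; the model, loader, and bounds below are placeholders:

# Hypothetical LR sweep: 100 iterations from the optimizer's current LR
# up to end_lr=1.0 on an exponential schedule.
lr_finder = LRFinder(model, optimizer, criterion, metric='loss')
lr_finder.range_test(train_loader, iterations=100, end_lr=1.0, step_mode='exp')
print('Best LR found:', lr_finder.best_lr)
lr_finder.plot(log_lr=True, show_lr=lr_finder.best_lr)
lr_finder.reset()  # restore the model and optimizer to their pre-test state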
Example #6
    def range_test(
        self,
        train_loader,
        iterations,
        mode='iteration',
        learner=None,
        val_loader=None,
        start_lr=None,
        end_lr=10,
        step_mode='exp',
        smooth_f=0.0,
        diverge_th=5,
    ):
        """Performs the learning rate range test.

        Args:
            train_loader (torch.utils.data.DataLoader): The training set data loader.
            iterations (int): The number of iterations/epochs over which the test occurs.
                If `mode` is set to 'iteration', this is the number of iterations;
                if `mode` is set to 'epoch', it is the number of epochs.
            mode (str, optional): Whether to update the learning rate after each
                'iteration' or after each 'epoch'. (default: 'iteration')
            learner (Learner, optional): Learner object for the model. (default: None) 
            val_loader (torch.utils.data.DataLoader, optional): If None, the range test
                will only use the training metric. When given a data loader, the model is
                evaluated after each iteration on that dataset and the evaluation metric
                is used. Note that in this mode the test takes significantly longer but
                generally produces more precise results. (default: None)
            start_lr (float, optional): The starting learning rate for the range test.
                If None, uses the learning rate from the optimizer. (default: None)
            end_lr (float, optional): The maximum learning rate to test. (default: 10)
            step_mode (str, optional): One of the available learning rate policies,
                linear or exponential ('linear', 'exp'). (default: 'exp')
            smooth_f (float, optional): The metric smoothing factor within the [0, 1)
                interval. Disabled if set to 0, otherwise the metric is smoothed using
                exponential smoothing. (default: 0.0)
            diverge_th (int, optional): The test is stopped when the metric surpasses the
                threshold: diverge_th * best_metric. To disable, set it to 0. (default: 5)
        """

        # Check if a valid 'mode' has been given
        if mode not in ['iteration', 'epoch']:
            raise ValueError(
                f'For "mode" expected one of (iteration, epoch), got {mode}')

        # Reset test results
        self.history = {'lr': [], 'metric': []}
        self.best_metric = None
        self.best_lr = None

        # Check if the optimizer is already attached to a scheduler
        self._check_for_scheduler()

        # Set the starting learning rate
        if start_lr:
            self._set_learning_rate(start_lr)

        # Initialize the proper learning rate policy
        if step_mode.lower() == 'exp':
            lr_schedule = ExponentialLR(self.optimizer, end_lr, iterations)
        elif step_mode.lower() == 'linear':
            lr_schedule = LinearLR(self.optimizer, end_lr, iterations)
        else:
            raise ValueError(f'Expected one of (exp, linear), got {step_mode}')

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError('smooth_f is outside the range [0, 1)')

        # Set accuracy metric if needed
        metrics = None
        if self.metric == 'accuracy':
            metrics = ['accuracy']

        # Get the learner object
        if learner is not None:
            self.learner = learner(train_loader,
                                   self.optimizer,
                                   self.criterion,
                                   device=self.device,
                                   val_loader=val_loader,
                                   metrics=metrics)
        else:
            self.learner = Learner(train_loader,
                                   self.optimizer,
                                   self.criterion,
                                   device=self.device,
                                   val_loader=val_loader,
                                   metrics=metrics)
        self.learner.set_model(self.model)

        train_iterator = InfiniteDataLoader(train_loader)
        pbar = ProgressBar(target=iterations, width=8)
        if mode == 'iteration':
            print(mode.title() + 's')
        for iteration in range(iterations):
            # Train model
            if mode == 'epoch':
                print(f'{mode.title()} {iteration + 1}:')
            self._train_model(mode, train_iterator)
            if val_loader:
                self.learner.validate(verbose=False)

            # Get metric value
            metric_value = self._get_metric(val_loader)

            # Update the learning rate
            lr_schedule.step()
            self.history['lr'].append(lr_schedule.get_lr()[0])

            # Track the best metric and smooth it if smooth_f is specified
            if iteration == 0:
                self.best_metric = metric_value
                self.best_lr = self.history['lr'][-1]
            else:
                if smooth_f > 0:
                    metric_value = smooth_f * metric_value + (
                        1 - smooth_f) * self.history['metric'][-1]
                if ((self.metric == 'loss' and metric_value < self.best_metric)
                        or (self.metric == 'accuracy'
                            and metric_value > self.best_metric)):
                    self.best_metric = metric_value
                    self.best_lr = self.history['lr'][-1]

            # Check if the metric has diverged; if it has, stop the test.
            # Compare raw values before converting to the display scale so that
            # accuracy (stored in [0, 1]) stays on the same scale as best_metric.
            self.history['metric'].append(metric_value)
            diverged = (diverge_th > 0
                        and ((self.metric == 'loss'
                              and metric_value > self.best_metric * diverge_th) or
                             (self.metric == 'accuracy'
                              and metric_value < self.best_metric / diverge_th)))
            metric_value = self._display_metric_value(metric_value)
            if diverged:
                if mode == 'iteration':
                    pbar.update(iterations - 1,
                                values=[('lr', self.history['lr'][-1]),
                                        (self.metric.title(), metric_value)])
                print('\nStopping early, the metric has diverged.')
                break
            else:
                if mode == 'epoch':
                    lr = self.history['lr'][-1]
                    print(
                        f'Learning Rate: {lr:.4f}, {self.metric.title()}: {metric_value:.2f}\n'
                    )
                elif mode == 'iteration':
                    pbar.update(iteration,
                                values=[('lr', self.history['lr'][-1]),
                                        (self.metric.title(), metric_value)])

        metric = self._display_metric_value(self.best_metric)
        if mode == 'epoch':
            print(
                f'Learning Rate: {self.best_lr:.4f}, {self.metric.title()}: {metric:.2f}\n'
            )
        elif mode == 'iteration':
            pbar.add(1,
                     values=[('lr', self.best_lr),
                             (self.metric.title(), metric)])
        print('Learning rate search finished.')
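The smoothing applied inside the loop is plain exponential smoothing; a standalone sketch of the same update (the helper name is illustrative):

def smooth(previous, new_value, smooth_f=0.05):
    # smoothed = smooth_f * new_value + (1 - smooth_f) * previous,
    # exactly the update applied to metric_value in the loop above.
    if previous is None:  # first observation: nothing to smooth against
        return new_value
    return smooth_f * new_value + (1 - smooth_f) * previous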
Example #7
class BaseModel(nn.Module):
    """This is the parent class for all the models that are to be
    created using ``TensorNet``."""
    def __init__(self):
        """This function instantiates all the model layers."""
        super(BaseModel, self).__init__()
        self.learner = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """This function defines the forward pass of the model.

        Args:
            x (torch.Tensor): Input.

        Returns:
            (*torch.Tensor*): Model output.
        """
        raise NotImplementedError

    def summary(self, input_size: Tuple[int]):
        """Generates model summary.

        Args:
            input_size (tuple): Size of input to the model.
        """
        model_summary(self, input_size)

    def create_learner(self,
                       train_loader,
                       optimizer,
                       criterion,
                       device='cpu',
                       epochs=1,
                       l1_factor=0.0,
                       val_loader=None,
                       callbacks=None,
                       metrics=None,
                       activate_loss_logits=False,
                       record_train=True):
        """Create Learner object.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (:obj:`str` or :obj:`torch.device`): Device where the data will be loaded.
            epochs (:obj:`int`, optional): Number of epochs to train the model. (default: 1)
            l1_factor (:obj:`float`, optional): L1 regularization factor. (default: 0)
            val_loader (:obj:`torch.utils.data.DataLoader`, optional): Validation data loader.
                (default: None)
            callbacks (:obj:`list`, optional): List of callbacks to be used during training.
                (default: None)
            metrics (:obj:`list`, optional): List of names of the metrics for model evaluation.
                (default: None)
        """
        self.learner = Learner(train_loader,
                               optimizer,
                               criterion,
                               device=device,
                               epochs=epochs,
                               val_loader=val_loader,
                               l1_factor=l1_factor,
                               callbacks=callbacks,
                               metrics=metrics,
                               activate_loss_logits=activate_loss_logits,
                               record_train=record_train)
        self.learner.set_model(self)

    def set_learner(self, learner: Learner):
        """Assign a learner object to the model.

        Args:
            learner (:obj:`Learner`): Learner object.
        """
        self.learner = learner
        self.learner.set_model(self)

    def fit(
        self,
        train_loader,
        optimizer,
        criterion,
        device='cpu',
        epochs=1,
        l1_factor=0.0,
        val_loader=None,
        callbacks=None,
        metrics=None,
        activate_loss_logits=False,
        record_train=True,
        start_epoch=1,
    ):
        """Train the model.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (:obj:`str` or :obj:`torch.device`): Device where the data will be loaded.
            epochs (:obj:`int`, optional): Number of epochs to train the model. (default: 1)
            l1_factor (:obj:`float`, optional): L1 regularization factor. (default: 0)
            val_loader (:obj:`torch.utils.data.DataLoader`, optional): Validation data loader.
                (default: None)
            callbacks (:obj:`list`, optional): List of callbacks to be used during training.
                (default: None)
            metrics (:obj:`list`, optional): List of names of the metrics for model evaluation.
                (default: None)
            start_epoch (:obj:`int`, optional): Starting epoch number to display during training.
                (default: 1)
        """

        # Create learner object
        self.create_learner(
            train_loader,
            optimizer,
            criterion,
            device=device,
            epochs=epochs,
            l1_factor=l1_factor,
            val_loader=val_loader,
            callbacks=callbacks,
            metrics=metrics,
            activate_loss_logits=activate_loss_logits,
            record_train=record_train,
        )

        # Train Model
        self.learner.fit(start_epoch=start_epoch)

    def save(self, filepath: str, **kwargs):
        """Save the model.

        Args:
            filepath (str): File in which the model will be saved.
            **kwargs: Additional parameters to save with the model.
        """
        if self.learner is None:
            raise ValueError('Cannot save an untrained model.')

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.learner.optimizer.state_dict(),
                **kwargs
            }, filepath)

    def load(self, filepath: str) -> dict:
        """Load the model and return the additional parameters saved in
        in the checkpoint file.

        Args:
            filepath (str): File in which the model is saved.

        Returns:
            (*dict*): Parameters saved inside the checkpoint file.
        """
        checkpoint = torch.load(filepath)
        self.load_state_dict(checkpoint['model_state_dict'])
        return {k: v for k, v in checkpoint.items() if k != 'model_state_dict'}
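Since fit() forwards start_epoch to the learner, resuming an interrupted run might look like this (the checkpoint path and epoch numbers are hypothetical):

# Hypothetical resume: reload model weights and optimizer state, then
# continue with the epoch counter starting at 6.
extras = model.load('checkpoint.pt')
optimizer.load_state_dict(extras['optimizer_state_dict'])
model.fit(train_loader, optimizer, criterion,
          device='cuda', epochs=5, start_epoch=6)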