Code Example #1
File: base_model.py  Project: rvk007/EVA4
# Excerpt: assumes `torch.nn as nn`, `torchsummary`, and the project's own
# `Learner` class are imported elsewhere in base_model.py.
class BaseModel(nn.Module):

    def __init__(self):
        """This function instantiates all the model layers."""
        super(BaseModel, self).__init__()
        self.learner = None
    
    def forward(self, x):
        """This function defines the forward pass of the model.

        Args:
            x: Input.
        
        Returns:
            Model output.
        """
        raise NotImplementedError

    def summary(self, input_size):
        """Generates model summary.

        Args:
            input_size (tuple): Size of input to the model.
        """
        torchsummary.summary(self, input_size)

    def fit(
        self, train_loader, optimizer, criterion, device='cpu',
        epochs=1, l1_factor=0.0, val_loader=None, callbacks=None
    ):
        """Train the model.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (str or torch.device): Device where the data
                will be loaded.
            epochs (int, optional): Number of epochs to train the model. (default: 1)
            l1_factor (float, optional): L1 regularization factor. (default: 0)
            val_loader (torch.utils.data.DataLoader, optional): Validation data
                loader. (default: None)
            callbacks (list, optional): List of callbacks to be used during training.
                (default: None)
        """
        self.learner = Learner(
            self, optimizer, criterion, train_loader, device=device, epochs=epochs,
            val_loader=val_loader, l1_factor=l1_factor, callbacks=callbacks
        )
        self.learner.fit()
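
A minimal usage sketch for this variant (the `SimpleNet` subclass, the data loaders,
and the device choice below are assumptions for illustration; `Learner` is provided
by the project itself):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class SimpleNet(BaseModel):  # hypothetical subclass, only for illustration
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)   # 3-channel input, e.g. 32x32 images
        self.fc = nn.Linear(16 * 32 * 32, 10)

    def forward(self, x):
        x = F.relu(self.conv(x))
        return self.fc(x.view(x.size(0), -1))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleNet().to(device)
model.summary((3, 32, 32))   # torchsummary defaults to CUDA; adjust on a CPU-only machine

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# train_loader / val_loader are torch.utils.data.DataLoader objects you provide
model.fit(train_loader, optimizer, criterion, device=device,
          epochs=5, val_loader=val_loader)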
Code Example #2
File: base_model.py  Project: rbk1988/EVA4-Phase-2
# Excerpt: assumes `torch`, `torch.nn as nn`, and the project's own `Learner`
# and `model_summary` helpers are imported elsewhere in base_model.py.
class BaseModel(nn.Module):
    def __init__(self):
        """This function instantiates all the model layers."""
        super(BaseModel, self).__init__()
        self.learner = None

    def forward(self, x):
        """This function defines the forward pass of the model.

        Args:
            x: Input.
        
        Returns:
            Model output.
        """
        raise NotImplementedError

    def summary(self, input_size):
        """Generates model summary.

        Args:
            input_size (tuple): Size of input to the model.
        """
        model_summary(self, input_size)

    def create_learner(self,
                       train_loader,
                       optimizer,
                       criterion,
                       device='cpu',
                       epochs=1,
                       l1_factor=0.0,
                       val_loader=None,
                       callbacks=None,
                       metrics=None,
                       activate_loss_logits=False,
                       record_train=True):
        """Create Learner object.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (str or torch.device): Device where the data
                will be loaded.
            epochs (int, optional): Number of epochs to train the model. (default: 1)
            l1_factor (float, optional): L1 regularization factor. (default: 0)
            val_loader (torch.utils.data.DataLoader, optional): Validation data
                loader. (default: None)
            callbacks (list, optional): List of callbacks to be used during training.
                (default: None)
            metrics (list of str, optional): List of names of the metrics for model
                evaluation. (default: None)
        """
        self.learner = Learner(train_loader,
                               optimizer,
                               criterion,
                               device=device,
                               epochs=epochs,
                               val_loader=val_loader,
                               l1_factor=l1_factor,
                               callbacks=callbacks,
                               metrics=metrics,
                               activate_loss_logits=activate_loss_logits,
                               record_train=record_train)
        self.learner.set_model(self)

    def set_learner(self, learner):
        """Assign a learner object to the model.

        Args:
            learner (Learner): Learner object.
        """
        self.learner = learner
        self.learner.set_model(self)

    def fit(self, *args, start_epoch=1, **kwargs):
        """Train the model.

        Args:
            start_epoch (int, optional): Start epoch for training.
                (default: 1)
        """

        # Check learner
        if self.learner is None:
            print('Creating a learner object.')
            self.create_learner(*args, **kwargs)

        # Train Model
        self.learner.fit(start_epoch=start_epoch)

    def save_learnable(self, filepath, **kwargs):
        """Save the learnable model.

        Args:
            filepath (str): File in which the model will be saved.
            **kwargs (optional): Additional parameters to save with the model.
        """
        if self.learner is None:
            raise ValueError('Cannot save un-trained model.')

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.learner.optimizer.state_dict(),
                **kwargs
            }, filepath)

    def save(self, filepath):
        """Save the model.

        Args:
            filepath (str): File in which the model will be saved.
        """
        torch.save(self, filepath)

    def load(self, filepath):
        """Load the model.

        Args:
            filepath (str): File in which the model is saved.
        
        Returns:
            Parameters saved inside the checkpoint file.
        """
        checkpoint = torch.load(filepath)
        self.load_state_dict(checkpoint['model_state_dict'])
        return {k: v for k, v in checkpoint.items() if k != 'model_state_dict'}
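
This variant separates Learner construction from training and adds checkpoint helpers.
A hedged sketch of that flow, reusing the hypothetical `SimpleNet` subclass from the note
after Example #1 (the loaders and the metric name are assumptions, not a documented API):

import torch

model = SimpleNet()                                   # any BaseModel subclass
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# Build the Learner explicitly, then train from it.
model.create_learner(train_loader, optimizer, criterion,
                     device='cpu', epochs=3, val_loader=val_loader,
                     metrics=['accuracy'])            # metric names are project-defined; 'accuracy' is assumed
model.fit(start_epoch=1)                              # reuses the learner created above

# Checkpoint the trained weights plus optimizer state, then restore later.
model.save_learnable('checkpoint.pt', epoch=3)        # extra kwargs are stored alongside the state dicts
extras = model.load('checkpoint.pt')                  # returns everything except 'model_state_dict'
optimizer.load_state_dict(extras['optimizer_state_dict'])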
Code Example #3
File: base_model.py  Project: shan18/TensorNet
# Excerpt: assumes `torch`, `torch.nn as nn`, `typing.Tuple`, and TensorNet's
# own `Learner` and `model_summary` are imported elsewhere in base_model.py.
class BaseModel(nn.Module):
    """This is the parent class for all the models that are to be
    created using ``TensorNet``."""
    def __init__(self):
        """This function instantiates all the model layers."""
        super(BaseModel, self).__init__()
        self.learner = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """This function defines the forward pass of the model.

        Args:
            x (torch.Tensor): Input.

        Returns:
            (*torch.Tensor*): Model output.
        """
        raise NotImplementedError

    def summary(self, input_size: Tuple[int]):
        """Generates model summary.

        Args:
            input_size (tuple): Size of input to the model.
        """
        model_summary(self, input_size)

    def create_learner(self,
                       train_loader,
                       optimizer,
                       criterion,
                       device='cpu',
                       epochs=1,
                       l1_factor=0.0,
                       val_loader=None,
                       callbacks=None,
                       metrics=None,
                       activate_loss_logits=False,
                       record_train=True):
        """Create Learner object.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (:obj:`str` or :obj:`torch.device`): Device where the data will be loaded.
            epochs (:obj:`int`, optional): Number of epochs to train the model. (default: 1)
            l1_factor (:obj:`float`, optional): L1 regularization factor. (default: 0)
            val_loader (:obj:`torch.utils.data.DataLoader`, optional): Validation data loader.
            callbacks (:obj:`list`, optional): List of callbacks to be used during training.
            metrics (:obj:`list`, optional): List of names of the metrics for model evaluation.
        """
        self.learner = Learner(train_loader,
                               optimizer,
                               criterion,
                               device=device,
                               epochs=epochs,
                               val_loader=val_loader,
                               l1_factor=l1_factor,
                               callbacks=callbacks,
                               metrics=metrics,
                               activate_loss_logits=activate_loss_logits,
                               record_train=record_train)
        self.learner.set_model(self)

    def set_learner(self, learner: Learner):
        """Assign a learner object to the model.

        Args:
            learner (:obj:`Learner`): Learner object.
        """
        self.learner = learner
        self.learner.set_model(self)

    def fit(
        self,
        train_loader,
        optimizer,
        criterion,
        device='cpu',
        epochs=1,
        l1_factor=0.0,
        val_loader=None,
        callbacks=None,
        metrics=None,
        activate_loss_logits=False,
        record_train=True,
        start_epoch=1,
    ):
        """Train the model.

        Args:
            train_loader (torch.utils.data.DataLoader): Training data loader.
            optimizer (torch.optim): Optimizer for the model.
            criterion (torch.nn): Loss Function.
            device (:obj:`str` or :obj:`torch.device`): Device where the data will be loaded.
            epochs (:obj:`int`, optional): Number of epochs to train the model. (default: 1)
            l1_factor (:obj:`float`, optional): L1 regularization factor. (default: 0)
            val_loader (:obj:`torch.utils.data.DataLoader`, optional): Validation data loader.
            callbacks (:obj:`list`, optional): List of callbacks to be used during training.
            metrics (:obj:`list`, optional): List of names of the metrics for model evaluation.
            start_epoch (:obj:`int`, optional): Starting epoch number to display during training.
                (default: 1)
        """

        # Create learner object
        self.create_learner(
            train_loader,
            optimizer,
            criterion,
            device=device,
            epochs=epochs,
            l1_factor=l1_factor,
            val_loader=val_loader,
            callbacks=callbacks,
            metrics=metrics,
            activate_loss_logits=activate_loss_logits,
            record_train=record_train,
        )

        # Train Model
        self.learner.fit(start_epoch=start_epoch)

    def save(self, filepath: str, **kwargs):
        """Save the model.

        Args:
            filepath (str): File in which the model will be saved.
            **kwargs: Additional parameters to save with the model.
        """
        if self.learner is None:
            raise ValueError('Cannot save un-trained model.')

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.learner.optimizer.state_dict(),
                **kwargs
            }, filepath)

    def load(self, filepath: str) -> dict:
        """Load the model and return the additional parameters saved in
        in the checkpoint file.

        Args:
            filepath (str): File in which the model is saved.

        Returns:
            (*dict*): Parameters saved inside the checkpoint file.
        """
        checkpoint = torch.load(filepath)
        self.load_state_dict(checkpoint['model_state_dict'])
        return {k: v for k, v in checkpoint.items() if k != 'model_state_dict'}
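
In the TensorNet variant, fit() always builds the Learner itself, so a single call covers
setup and training, and save()/load() round-trip the weights plus any extra metadata.
A minimal sketch under the same assumptions as above (subclass, loaders, and metric name
are hypothetical stand-ins):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleNet().to(device)                        # any BaseModel subclass
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()

model.fit(train_loader, optimizer, criterion,
          device=device, epochs=10, val_loader=val_loader,
          metrics=['accuracy'])                       # 'accuracy' is an assumed metric name

model.save('model.pt', epochs_trained=10)             # extra kwargs land in the checkpoint dict
extras = model.load('model.pt')                       # restores weights, returns the rest
optimizer.load_state_dict(extras['optimizer_state_dict'])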