Example #1
class VAE(LightningModule):
    def __init__(
        self,
        hparams=None,
    ):
        """
        Convolutional Variational Autoencoder
        Parameters to be included in hparams
        -----------
        input_width: int
            input image width (must be even)
            Default: 28
        input_height: int
            input image height (must be even)
            Default: 28
        hidden_dim: int
            hidden layer dimension
            Default: 128
        latent_dim: int
            latent layer dimension
            Default: 32
        batch_size: int
            Batch Size for training
            Default: 32
        opt : str
            One of 'adam' or 'adamax' or 'rmsprop'.
            Default : 'adam'
        lr: float
            Learning rate for optimizer.
            Default : 0.001
        weight_decay: float
            Weight decay in optimizer.
            Default : 0   
        """
        super().__init__()
        # attach hparams to log hparams to the loggers (like tensorboard)
        self.__check_hparams(hparams)
        self.hparams = hparams

        # NOTE Change dataloaders appropriately
        self.dataloaders = MNISTDataLoaders(save_path=os.getcwd())
        self.telegrad_logs = {
        }  # log everything you want to be reported via telegram here

        self.encoder = self.init_encoder(self.hidden_dim, self.latent_dim,
                                         self.input_width, self.input_height)
        self.decoder = self.init_decoder(self.hidden_dim, self.latent_dim,
                                         self.input_width, self.input_height)

    def __check_hparams(self, hparams):
        self.hidden_dim = hparams.hidden_dim if hasattr(hparams,
                                                        'hidden_dim') else 128
        self.latent_dim = hparams.latent_dim if hasattr(hparams,
                                                        'latent_dim') else 32
        self.input_width = hparams.input_width if hasattr(
            hparams, 'input_width') else 28
        self.input_height = hparams.input_height if hasattr(
            hparams, 'input_height') else 28
        self.opt = hparams.opt if hasattr(hparams, 'opt') else 'adam'
        self.batch_size = hparams.batch_size if hasattr(hparams,
                                                        'batch_size') else 32
        self.lr = hparams.lr if hasattr(hparams, 'lr') else 0.001
        self.weight_decay = hparams.weight_decay if hasattr(
            hparams, 'weight_decay') else 0

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        encoder = Encoder(hidden_dim, latent_dim, input_width, input_height)
        return encoder

    def init_decoder(self, hidden_dim, latent_dim, input_width, input_height):
        decoder = Decoder(hidden_dim, latent_dim, input_width, input_height)
        return decoder

    def get_prior(self, z_mu, z_std):
        # Prior ~ Normal(0,1)
        P = distributions.normal.Normal(loc=torch.zeros_like(z_mu),
                                        scale=torch.ones_like(z_std))
        return P

    def get_approx_posterior(self, z_mu, z_std):
        # Approx Posterior ~ Normal(mu, sigma)
        Q = distributions.normal.Normal(loc=z_mu, scale=z_std)
        return Q

    def elbo_loss(self, x, P, Q):
        # Reconstruction loss
        z = Q.rsample()
        pxz = self(z)
        recon_loss = F.binary_cross_entropy(pxz, x, reduction='none')

        # sum over pixel dimensions: summing the per-pixel log probabilities gives the
        # log probability of the whole (factorised) image, just as a product of iid
        # univariate gaussians is equivalent to a single multivariate gaussian
        recon_loss = recon_loss.sum(dim=-1)

        # KL divergence loss
        log_qz = Q.log_prob(z)
        log_pz = P.log_prob(z)
        kl_div = (log_qz - log_pz).sum(dim=1)

        # ELBO = reconstruction + KL
        loss = recon_loss + kl_div

        # average over batch
        loss = loss.mean()
        recon_loss = recon_loss.mean()
        kl_div = kl_div.mean()

        return loss, recon_loss, kl_div, pxz

    def forward(self, z):
        return self.decoder(z)

    def _run_step(self, batch):
        x, _ = batch
        z_mu, z_log_var = self.encoder(x)
        z_std = torch.exp(z_log_var / 2)

        P = self.get_prior(z_mu, z_std)
        Q = self.get_approx_posterior(z_mu, z_std)

        x = x.view(x.size(0), -1)

        loss, recon_loss, kl_div, pxz = self.elbo_loss(x, P, Q)

        return loss, recon_loss, kl_div, pxz

    def training_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        logs = {
            'train_elbo_loss': loss,
            'train_recon_loss': recon_loss,
            'train_kl_loss': kl_div
        }
        return {'loss': loss, 'log': logs}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['train_elbo_loss'] for x in outputs]).mean()
        recon_loss = torch.stack([x['train_recon_loss']
                                  for x in outputs]).mean()
        kl_loss = torch.stack([x['train_kl_loss'] for x in outputs]).mean()

        logs = {
            'train_elbo_loss_epoch': avg_loss,
            'train_recon_loss_epoch': recon_loss,
            'train_kl_loss_epoch': kl_loss
        }
        self.telegrad_logs['lr'] = self.lr  # for telegram bot
        self.telegrad_logs['trainer_loss_epoch'] = avg_loss.item(
        )  # for telegram bot
        self.telegrad_logs['train_recon_loss_epoch'] = recon_loss.item(
        )  # for telegram bot
        self.telegrad_logs['train_kl_loss_epoch'] = kl_loss.item(
        )  # for telegram bot
        self.logger.log_metrics({'learning_rate':
                                 self.lr})  # if lr is changed by telegram bot
        return {'avg_train_loss': avg_loss, 'log': logs}

    def validation_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        return {
            'val_loss': loss,
            'val_recon_loss': recon_loss,
            'val_kl_div': kl_div,
            'pxz': pxz
        }

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        recon_loss = torch.stack([x['val_recon_loss'] for x in outputs]).mean()
        kl_loss = torch.stack([x['val_kl_div'] for x in outputs]).mean()

        logs = {
            'val_elbo_loss': avg_loss,
            'val_recon_loss': recon_loss,
            'val_kl_loss': kl_loss
        }

        self.telegrad_logs['val_loss_epoch'] = avg_loss.item(
        )  # for telegram bot
        self.telegrad_logs['val_recon_loss_epoch'] = recon_loss.item(
        )  # for telegram bot
        self.telegrad_logs['val_kl_loss_epoch'] = kl_loss.item(
        )  # for telegram bot

        return {'avg_val_loss': avg_loss, 'log': logs}

    def test_step(self, batch, batch_idx):
        loss, recon_loss, kl_div, pxz = self._run_step(batch)

        return {
            'test_loss': loss,
            'test_recon_loss': recon_loss,
            'test_kl_div': kl_div,
            'pxz': pxz
        }

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        recon_loss = torch.stack([x['test_recon_loss']
                                  for x in outputs]).mean()
        kl_loss = torch.stack([x['test_kl_div'] for x in outputs]).mean()

        logs = {
            'test_elbo_loss': avg_loss,
            'test_recon_loss': recon_loss,
            'test_kl_loss': kl_loss
        }

        return {'avg_test_loss': avg_loss, 'log': logs}

    def configure_optimizers(self):
        return optimizers[self.opt](self.parameters(),
                                    lr=self.lr,
                                    weight_decay=self.weight_decay)

    def prepare_data(self):
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        return self.dataloaders.train_dataloader(self.batch_size)

    def val_dataloader(self):
        return self.dataloaders.val_dataloader(self.batch_size)

    def test_dataloader(self):
        return self.dataloaders.test_dataloader(self.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            '--hidden_dim',
            type=int,
            default=128,
            help=
            'intermediate layer dimension before the embedding, for the default encoder/decoder'
        )
        parser.add_argument('--latent_dim',
                            type=int,
                            default=32,
                            help='dimension of latent variables z')
        parser.add_argument(
            '--input_width',
            type=int,
            default=28,
            help='input image width - 28 for MNIST (must be even)')
        parser.add_argument(
            '--input_height',
            type=int,
            default=28,
            help='input image height - 28 for MNIST (must be even)')
        parser.add_argument('--batch_size',
                            type=int,
                            default=32,
                            help='batch size for training')
        # optimizer
        parser.add_argument('--opt',
                            type=str,
                            default='adam',
                            choices=['adam', 'adamax', 'rmsprop'],
                            help='optimizer type for optimization')
        parser.add_argument('--lr',
                            type=float,
                            default=0.001,
                            help='learning rate')
        parser.add_argument('--weight_decay',
                            type=float,
                            default=0,
                            help='weight decay in optimizer')
        return parser
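
A minimal driver script for this template, sketched below on the assumption that pytorch_lightning is installed and that the repository modules providing MNISTDataLoaders, Encoder and Decoder are importable:

# Hypothetical training script, not part of the template above.
from argparse import ArgumentParser
import pytorch_lightning as pl

parser = ArgumentParser()
parser = VAE.add_model_specific_args(parser)
hparams = parser.parse_args()

model = VAE(hparams)
trainer = pl.Trainer(max_epochs=10)  # add gpus=1, loggers, etc. as needed
trainer.fit(model)
trainer.test(model)
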
Example #2
class AutoEncoder(LightningModule):
    """
        Linear Autoencoder
    """
    def __init__(self, hparams=None):
        """
        Linear Autoencoder.
        Parameters
        ----------
        input_shape : int
            Dimension of input vector.
            Example: 784
        latent_dim : int
            Dimension of latent vector.
            Example: 256
        activation : str
            One of 'relu', 'sigmoid' or 'tanh' (the default is 'relu').
        opt : str
            One of 'adam', 'adamax' or 'rmsprop' (default is 'adam').
        batch_size: int
            Batch size for training (default is 32)
        lr: float
            Learning rate for optimizer (default is 0.001)
        weight_decay: float
            Weight decay in optimizer (default is 0)
        """
        super(AutoEncoder, self).__init__()
        self.__check_hparams(hparams)
        self.hparams = hparams

        self.encoder = Encoder(self.input_shape, self.latent_dim,
                               self.activation)
        self.decoder = Decoder(self.input_shape, self.latent_dim,
                               self.activation)

        # NOTE Change dataloaders appropriately
        self.dataloaders = MNISTDataLoaders(save_path=os.getcwd())
        self.telegrad_logs = {
        }  # log everything you want to be reported via telegram here

    def __check_hparams(self, hparams):
        self.input_shape = hparams.input_shape if hasattr(
            hparams, 'input_shape') else 784
        self.latent_dim = hparams.latent_dim if hasattr(hparams,
                                                        'latent_dim') else 256
        self.opt = hparams.opt if hasattr(hparams, 'opt') else 'adam'
        self.batch_size = hparams.batch_size if hasattr(hparams,
                                                        'batch_size') else 32
        self.lr = hparams.lr if hasattr(hparams, 'lr') else 0.001
        self.weight_decay = hparams.weight_decay if hasattr(
            hparams, 'weight_decay') else 0
        self.activation = hparams.activation if hasattr(
            hparams, 'activation') else 'relu'
        self.act = ACTS[self.activation]

    def forward(self, x):
        # NOTE comment the line below, just for testing MNIST
        x = x.view(-1, self.input_shape)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def configure_optimizers(self):
        """
            Choose Optimizer
        """
        return optimizers[self.opt](self.parameters(),
                                    lr=self.lr,
                                    weight_decay=self.weight_decay)

    def training_step(self, batch, batch_idx):
        """
            Define one training step
        """
        x, y = batch
        # NOTE comment the line below, just for testing MNIST
        x = x.view(-1, self.input_shape)
        x_hat = self(x)  # get predictions from network
        criterion = nn.MSELoss()
        loss = criterion(x_hat, x)
        log = {'trainer_loss': loss}
        return {'loss': loss, 'log': log}

    def training_epoch_end(self, outputs):
        """
            Train Loss at the end of epoch
            Will store logs
        """
        avg_loss = torch.stack([x['trainer_loss'] for x in outputs]).mean()
        logs = {'trainer_loss_epoch': avg_loss}
        self.telegrad_logs['lr'] = self.lr  # for telegram bot
        self.telegrad_logs['trainer_loss_epoch'] = avg_loss.item(
        )  # for telegram bot
        self.logger.log_metrics({'learning_rate':
                                 self.lr})  # if lr is changed by telegram bot
        return {'train_loss': avg_loss, 'log': logs}

    def validation_step(self, batch, batch_idx):
        """
            One validation step
        """
        x, y = batch
        # NOTE comment the line below, just for testing MNIST
        x = x.view(-1, self.input_shape)
        criterion = nn.MSELoss()
        x_hat = self(x)
        return {'val_loss': criterion(x_hat, x)}

    def validation_epoch_end(self, outputs):
        """
            Validation at the end of epoch
            Will store logs
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        logs = {'val_loss': avg_loss}
        self.telegrad_logs['val_loss_epoch'] = avg_loss.item(
        )  # for telegram bot
        return {'val_loss': avg_loss, 'log': logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        # NOTE comment the line below, just for testing MNIST
        x = x.view(-1, self.input_shape)
        criterion = nn.MSELoss()
        x_hat = self(x)
        return {'test_loss': criterion(x_hat, x)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': logs}

    def prepare_data(self):
        """
            Prepare the dataset by downloading it.
            Runs only the first time, when the
            dataset is not already available.
        """
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        """
            Refer dataset.py to make custom dataloaders
        """
        return self.dataloaders.train_dataloader(self.batch_size)

    def val_dataloader(self):
        return self.dataloaders.val_dataloader(self.batch_size)

    def test_dataloader(self):
        return self.dataloaders.test_dataloader(self.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--input_shape',
                            type=int,
                            default=784,
                            help='input vector shape for MNIST')
        parser.add_argument('--latent_dim',
                            type=int,
                            default=256,
                            help='dimension of the latent vector')
        parser.add_argument('--activation',
                            type=str,
                            default='relu',
                            choices=['relu', 'sigmoid', 'tanh'],
                            help='activations for nn layers')
        parser.add_argument('--batch_size',
                            type=int,
                            default=32,
                            help='batch size for training')
        # optimizer
        parser.add_argument('--opt',
                            type=str,
                            default='adam',
                            choices=['adam', 'adamax', 'rmsprop'],
                            help='optimizer type for optimization')
        parser.add_argument('--lr',
                            type=float,
                            default=0.001,
                            help='learning rate')
        parser.add_argument('--weight_decay',
                            type=float,
                            default=0,
                            help='weight decay in optimizer')
        return parser
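
These templates reference two module-level lookup tables, ACTS and optimizers, that the repository defines elsewhere. A minimal sketch of what they plausibly contain, consistent with the string choices accepted by the argument parsers ('relu'/'sigmoid'/'tanh' and 'adam'/'adamax'/'rmsprop'):

# Assumed definitions; the repository ships its own versions of these tables.
import torch.nn as nn
import torch.optim as optim

ACTS = {
    'relu': nn.ReLU,        # used as ACTS[name]() to build an activation module
    'sigmoid': nn.Sigmoid,
    'tanh': nn.Tanh,
}

optimizers = {
    'adam': optim.Adam,     # used as optimizers[name](params, lr=..., weight_decay=...)
    'adamax': optim.Adamax,
    'rmsprop': optim.RMSprop,
}
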
Example #3
class MLP(LightningModule):
    def __init__(self, hparams=None):
        """
        Multi-layer perceptron with two layers
        Parameters to be included in hparams
        ----------
        input_shape : int
            Dimension of input vector.
            Default : 784 
        num_outputs : int
            Dimension of output vector.
            Default : 10
        activation : str
            One of 'relu', 'sigmoid' or 'tanh'.
            Default : 'relu'
        opt : str
            One of 'adam' or 'adamax' or 'rmsprop'.
            Default : 'adam'
        batch_size: int
            Batch size for training.
            Default : 32
        lr: float
            Learning rate for optimizer.
            Default : 0.001
        weight_decay: float
            Weight decay in optimizer.
            Default : 0
        """
        self.__check_hparams(hparams)
        self.hparams = hparams

        # NOTE Change dataloaders appropriately
        self.dataloaders = MNISTDataLoaders(save_path=os.getcwd())
        self.telegrad_logs = {} # log everything you want to be reported via telegram here
        
        self.fc = nn.Sequential(
            nn.Linear(self.input_shape, self.hidden_dim[0]),
            self.act(),
            nn.Linear(self.hidden_dim[0], self.hidden_dim[1]),
            self.act(),
            nn.Linear(self.hidden_dim[1], self.num_outputs)
        )
        
    def __check_hparams(self, hparams):
        self.input_shape = hparams.input_shape if hasattr(hparams,'input_shape') else 784
        self.num_outputs = hparams.num_outputs if hasattr(hparams,'num_outputs') else 10
        self.hidden_dim = hparams.hidden_dim if hasattr(hparams,'hidden_dim') else [512,256]
        self.opt = hparams.opt if hasattr(hparams,'opt') else 'adam'
        self.batch_size = hparams.batch_size if hasattr(hparams,'batch_size') else 32
        self.lr = hparams.lr if hasattr(hparams,'lr') else 0.001
        self.weight_decay = hparams.weight_decay if hasattr(hparams,'weight_decay') else 0
        self.activation = hparams.activation if hasattr(hparams,'activation') else 'relu'
        self.act = ACTS[self.activation]

    def forward(self, x):
        # NOTE comment the line below, just for testing purposes
        x = x.view(-1, self.input_shape)
        x = self.fc(x)
        return x

    def configure_optimizers(self):
        """
            Choose Optimizer
        """
        return optimizers[self.opt](self.parameters(), lr=self.lr, weight_decay=self.weight_decay)


    def training_step(self, batch, batch_idx):
        """
            Define one training step
        """
        x, y = batch
        y_hat = self(x)  # get predictions from network
        loss = F.cross_entropy(y_hat, y)
        log = {'trainer_loss':loss}
        # self.logger.experiment.log_metric('train_loss',loss)
        return {'loss': loss, 'log': log}
    
    def training_epoch_end(self, outputs):
        """
            Train Loss at the end of epoch
            Will store logs
        """
        avg_loss = torch.stack([x['trainer_loss'] for x in outputs]).mean()
        logs = {'trainer_loss_epoch': avg_loss}
        self.telegrad_logs['lr'] = self.lr # for telegram bot
        self.telegrad_logs['trainer_loss_epoch'] = avg_loss.item() # for telegram bot
        self.logger.log_metrics({'learning_rate':self.lr}) # if lr is changed by telegram bot
        return {'train_loss': avg_loss, 'log': logs}

    def validation_step(self, batch, batch_idx):
        """
            One validation step
        """
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        """
            Validation at the end of epoch
            Will store logs
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        logs = {'val_loss_epoch': avg_loss}
        self.telegrad_logs['val_loss_epoch'] = avg_loss.item() # for telegram bot
        return {'val_loss': avg_loss, 'log': logs}
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss_epoch': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': logs}
    
    def prepare_data(self):
        """
            Prepare the dataset by downloading it.
            Runs only the first time, when the
            dataset is not already available.
        """
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        """
            Refer dataset.py to make custom dataloaders
        """
        return self.dataloaders.train_dataloader(self.batch_size)

    def val_dataloader(self):
        return self.dataloaders.val_dataloader(self.batch_size)
    
    def test_dataloader(self):
        return self.dataloaders.test_dataloader(self.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--input_shape', type=int, default=784,
                            help='input vector shape for MNIST')
        parser.add_argument('--num_outputs', type=int, default=10,
                            help='output vector shape for MNIST')
        parser.add_argument('--hidden_dim', type=int, nargs='+', default=[512, 256],
                            help='hidden layer sizes')
        parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'sigmoid', 'tanh'],
                            help='activations for nn layers')
        parser.add_argument('--batch_size', type=int, default=32,
                            help='batch size for training')
        # optimizer
        parser.add_argument('--opt', type=str, default='adam', choices=['adam', 'adamax', 'rmsprop'],
                            help='optimizer type for optimization')
        parser.add_argument('--lr', type=float, default=0.001,
                            help='learning rate')
        parser.add_argument('--weight_decay', type=float, default=0,
                            help='weight decay in optimizer')
        return parser
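
A quick shape sanity check for the MLP with its default hparams; this is only a sketch, and it assumes the repository's MNISTDataLoaders is importable since the constructor instantiates it:

import torch

model = MLP()                    # hparams=None, so the defaults above are used
x = torch.randn(4, 1, 28, 28)    # a fake batch of four MNIST-sized images
logits = model(x)                # flattened to (4, 784) inside forward
print(logits.shape)              # torch.Size([4, 10])
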
Example #4
class ConvNet(LightningModule):
    def __init__(self, hparams=None):
        """
        CNN followed by fully connected layers.
        Uses three strided conv layers (no pooling) before the fully connected head.
        Parameters
        ----------
        input_shape : tuple
            Dimension of input image as (channels, height, width).
            Example: (3,210,150)
        num_outputs : int
            Dimension of output.
            Example: 10
        activation : str
            One of 'relu', 'sigmoid' or 'tanh' (the default is 'relu').
        opt : str
            One of 'adam', 'adamax' or 'rmsprop' (default is 'adam').
        batch_size: int
            Batch size for training (default is 32)
        lr: float
            Learning rate for optimizer (default is 0.001)
        weight_decay: float
            Weight decay in optimizer (default is 0)
        """
        self.__check_hparams(hparams)
        self.hparams = hparams

        # NOTE Change dataloaders appropriately
        self.dataloaders = MNISTDataLoaders(save_path=os.getcwd())
        self.telegrad_logs = {} # log everything you want to be reported via telegram here
        
        self.features = nn.Sequential(
            nn.Conv2d(self.input_shape[0], 32, kernel_size=3, stride=4),
            self.act(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            self.act(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            self.act()
        )
        
        self.fc = nn.Sequential(
            nn.Linear(self.feature_size(), 512),
            self.act(),
            nn.Linear(512, self.num_outputs)
        )

    def __check_hparams(self, hparams):
        self.input_shape = hparams.input_shape if hasattr(hparams,'input_shape') else (1,28,28)
        self.num_outputs = hparams.num_outputs if hasattr(hparams,'num_outputs') else 10
        self.opt = hparams.opt if hasattr(hparams,'opt') else 'adam'
        self.batch_size = hparams.batch_size if hasattr(hparams,'batch_size') else 32
        self.lr = hparams.lr if hasattr(hparams,'lr') else 0.001
        self.weight_decay = hparams.weight_decay if hasattr(hparams,'weight_decay') else 0
        self.activation = hparams.activation if hasattr(hparams,'activation') else 'relu'
        self.act = ACTS[self.activation]
        
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
    
    def feature_size(self):
        """
            Get feature size after conv layers to flatten
        """
        return self.features(autograd.Variable(torch.zeros(1, *self.input_shape))).view(1, -1).size(1)

    def configure_optimizers(self):
        """
            Choose Optimizer
        """
        return optimizers[self.opt](self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def training_step(self, batch, batch_idx):
        """
            Define one training step
        """
        x, y = batch
        y_hat = self(x)  # get predictions from network
        loss = F.cross_entropy(y_hat, y)
        log = {'trainer_loss':loss}
        # self.logger.experiment.add_scalar('loss',loss)
        return {'loss': loss, 'log': log}

    def training_epoch_end(self, outputs):
        """
            Train Loss at the end of epoch
            Will store logs
        """
        avg_loss = torch.stack([x['trainer_loss'] for x in outputs]).mean()
        logs = {'trainer_loss_epoch': avg_loss}
        self.telegrad_logs['lr'] = self.lr # for telegram bot
        self.telegrad_logs['trainer_loss_epoch'] = avg_loss.item() # for telegram bot
        self.logger.log_metrics({'learning_rate':self.lr}) # if lr is changed by telegram bot
        return {'train_loss': avg_loss, 'log': logs}


    def validation_step(self, batch, batch_idx):
        """
            One validation step
        """
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        """
            Validation at the end of epoch
            Will store logs
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        logs = {'val_loss_epoch': avg_loss}
        self.telegrad_logs['val_loss_epoch'] = avg_loss.item() # for telegram bot
        return {'val_loss': avg_loss, 'log': logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': logs}

    def prepare_data(self):
        """
            Prepare the dataset by downloading it.
            Runs only the first time, when the
            dataset is not already available.
        """
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        """
            Refer dataset.py to make custom dataloaders
        """
        return self.dataloaders.train_dataloader(self.batch_size)
    
    def val_dataloader(self):
        return self.dataloaders.val_dataloader(self.batch_size)

    def test_dataloader(self):
        return self.dataloaders.test_dataloader(self.batch_size)
    
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--input_shape', type=int, nargs=3, default=[1, 28, 28],
                            help='input image shape (channels height width) for MNIST')
        parser.add_argument('--num_outputs', type=int, default=10,
                            help='output vector shape for MNIST')
        parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'sigmoid', 'tanh'],
                            help='activations for nn layers')
        parser.add_argument('--batch_size', type=int, default=32,
                            help='batch size for training')
        # optimizer
        parser.add_argument('--opt', type=str, default='adam', choices=['adam', 'adamax', 'rmsprop'],
                            help='optimizer type for optimization')
        parser.add_argument('--lr', type=float, default=0.001,
                            help='learning rate')
        parser.add_argument('--weight_decay', type=float, default=0,
                            help='weight decay in optimizer')
        return parser
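
For the default (1, 28, 28) MNIST input, the value returned by feature_size() can be checked by hand with the standard formula for an unpadded convolution, out = (in - kernel) // stride + 1:

# Shape arithmetic behind feature_size() for a (1, 28, 28) input:
#   conv1 (k=3, s=4): (28 - 3) // 4 + 1 = 7   -> 32 x 7 x 7
#   conv2 (k=3, s=2): (7  - 3) // 2 + 1 = 3   -> 64 x 3 x 3
#   conv3 (k=3, s=1): (3  - 3) // 1 + 1 = 1   -> 64 x 1 x 1
# so the first Linear layer receives a 64-element feature vector.
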
Example #5
class GAN(LightningModule):

    def __init__(self, hparams=None):
        super().__init__()
        self.__check_hparams(hparams)
        self.hparams = hparams

        self.dataloaders = MNISTDataLoaders(save_path=os.getcwd())

        # networks
        self.generator = self.init_generator(self.img_dim)
        self.discriminator = self.init_discriminator(self.img_dim)

        # cache for generated images
        self.generated_imgs = None
        self.last_imgs = None

    def __check_hparams(self, hparams):
        self.input_channels = hparams.input_channels if hasattr(hparams, 'input_channels') else 1
        self.input_width = hparams.input_width if hasattr(hparams, 'input_width') else 28
        self.input_height = hparams.input_height if hasattr(hparams, 'input_height') else 28
        self.latent_dim = hparams.latent_dim if hasattr(hparams, 'latent_dim') else 32
        self.batch_size = hparams.batch_size if hasattr(hparams, 'batch_size') else 32
        self.b1 = hparams.b1 if hasattr(hparams, 'b1') else 0.5
        self.b2 = hparams.b2 if hasattr(hparams, 'b2') else 0.999
        self.learning_rate = hparams.learning_rate if hasattr(hparams, 'learning_rate') else 0.0002
        self.img_dim = (self.input_channels, self.input_width, self.input_height)

    def init_generator(self, img_dim):
        generator = Generator(latent_dim=self.latent_dim, img_shape=img_dim)
        return generator

    def init_discriminator(self, img_dim):
        discriminator = Discriminator(img_shape=img_dim)
        return discriminator

    def forward(self, z):
        """
        Inference generates images from latent vectors:
        x = gan(z)
        :param z: batch of latent vectors
        :return: generated images
        """
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def generator_step(self, x):
        # sample noise
        z = torch.randn(x.shape[0], self.latent_dim)
        z = z.type_as(x)

        # generate images
        self.generated_imgs = self(z)

        # ground truth result (ie: all real)
        real = torch.ones(x.size(0), 1)
        real = real.type_as(x)
        g_loss = self.generator_loss(real)

        tqdm_dict = {'g_loss': g_loss}
        output = OrderedDict({
            'loss': g_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def generator_loss(self, real):
        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), real)
        return g_loss

    def discriminator_loss(self, x):
        # how well can it label as real?
        valid = torch.ones(x.size(0), 1)
        valid = valid.type_as(x)

        real_loss = self.adversarial_loss(self.discriminator(x), valid)

        # how well can it label as fake?
        fake = torch.zeros(x.size(0), 1)
        fake = fake.type_as(x)

        fake_loss = self.adversarial_loss(
            self.discriminator(self.generated_imgs.detach()), fake)

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        return d_loss

    def discriminator_step(self, x):
        # Measure discriminator's ability to classify real from generated samples
        d_loss = self.discriminator_loss(x)

        tqdm_dict = {'d_loss': d_loss}
        output = OrderedDict({
            'loss': d_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, _ = batch
        self.last_imgs = x

        # train generator
        if optimizer_idx == 0:
            return self.generator_step(x)

        # train discriminator
        if optimizer_idx == 1:
            return self.discriminator_step(x)

    def configure_optimizers(self):
        lr = self.learning_rate
        b1 = self.b1
        b2 = self.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def prepare_data(self):
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        return self.dataloaders.train_dataloader(self.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--input_width', type=int, default=28,
                            help='input image width - 28 for MNIST (must be even)')
        parser.add_argument('--input_channels', type=int, default=1,
                            help='num channels')
        parser.add_argument('--input_height', type=int, default=28,
                            help='input image height - 28 for MNIST (must be even)')
        parser.add_argument('--learning_rate', type=float, default=0.0002, help="adam: learning rate")
        parser.add_argument('--b1', type=float, default=0.5,
                            help="adam: decay of first order momentum of gradient")
        parser.add_argument('--b2', type=float, default=0.999,
                            help="adam: decay of second order momentum of gradient")
        parser.add_argument('--latent_dim', type=int, default=100,
                            help="generator embedding dim")
        parser.add_argument('--batch_size', type=int, default=64, help="size of the batches")

        return parser
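
After training, sampling is just a forward pass through the generator. A hedged sketch, assuming the repository's Generator, Discriminator and MNISTDataLoaders are importable and that Generator reshapes its output to img_shape:

import torch

gan = GAN()                          # defaults: latent_dim=32, img_dim=(1, 28, 28)
z = torch.randn(16, gan.latent_dim)
with torch.no_grad():
    samples = gan(z)                 # forward() delegates to self.generator(z)
# expected shape, if Generator reshapes to img_shape: torch.Size([16, 1, 28, 28])
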
Example #6
class ConvNet(LightningModule):
    """
    CNN using 3x3 convs followed by fully connected layers.
    """
    def __init__(self, hparams=None):
        """CNN followed by fully connected layers.
        Performs one 2x2 max pool after the first conv.
        Parameters to be included in hparams
        ----------
        input_shape : int
            Dimension of input square image.
        channels : list of ints
            List of channels of conv layers including input channels
            (the default is [1,32,32,16,8]).
        filters : list of ints
            List of filter sizes for each of the conv layers
            Length of list should be one less than list of channels
            (the default is [3,3,3,3])
        denses : list of ints
            Sequence of linear layer outputs after the conv layers
            (the default is [10]).
        activation : str
            One of 'relu', 'sigmoid' or 'tanh' (the default is 'relu').
        opt : str
            One of 'adam' or 'adamax' or 'rmsprop'.
            Default : 'adam'
        batch_size: int
            Batch size for training.
            Default : 32
        lr: float
            Learning rate for optimizer.
            Default : 0.001
        weight_decay: float
            Weight decay in optimizer.
            Default : 0
        """
        super().__init__()
        self.__check_hparams(hparams)
        self.hparams = hparams

        # NOTE Change dataloaders appropriately
        self.dataloaders = MNISTDataLoaders(save_path=os.getcwd())
        self.telegrad_logs = {} # log everything you want to be reported via telegram here

        convs = [nn.Conv2d(kernel_size=k, in_channels=in_ch, out_channels=out_ch)
                 for in_ch, out_ch, k in zip(self.channels[:-1], self.channels[1:], self.filters)]

        if len(self.channels) <= 1:
            self.conv_net = None
            feature_count = self.input_shape*self.input_shape
        else:
            self.conv_net = nn.Sequential(
                convs[0],
                nn.MaxPool2d(kernel_size=2),
                self.act(),
                *[layer for tup in zip(convs[1:], [self.act() for _ in convs[1:]]) for layer in tup]
            )

            with torch.no_grad():
                test_inp = torch.randn(1, self.channels[0], self.input_shape, self.input_shape)
                features = self.conv_net(test_inp)
                feature_count = features.view(-1).shape[0]

        linears = [nn.Linear(in_f, out_f) for in_f, out_f in
                   zip([feature_count]+self.denses[:-1], self.denses)]

        self.dense = nn.Sequential(
            *[layer for tup in zip(linears, [self.act() for _ in linears]) for layer in tup][:-1]
        )

    def __check_hparams(self, hparams):
        self.input_shape = hparams.input_shape if hasattr(hparams,'input_shape') else 28
        self.channels = hparams.channels if hasattr(hparams, 'channels') else [1, 32, 32, 16, 8]
        self.filters = hparams.filters if hasattr(hparams, 'filters') else [3, 3, 3, 3]
        self.denses = hparams.denses if hasattr(hparams, 'denses') else [10]
        self.activation = hparams.activation if hasattr(hparams,'activation') else 'relu'
        self.opt = hparams.opt if hasattr(hparams,'opt') else 'adam'
        self.batch_size = hparams.batch_size if hasattr(hparams,'batch_size') else 32
        self.lr = hparams.lr if hasattr(hparams,'lr') else 0.001
        self.weight_decay = hparams.weight_decay if hasattr(hparams,'weight_decay') else 0
        self.act = ACTS[self.activation]

    def forward(self, input):
        if self.conv_net:
            input = self.conv_net(input)
        out = self.dense(input.view(input.shape[0], -1))
        return out

    def configure_optimizers(self):
        return optimizers[self.opt](self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)  # get predictions from network
        loss = F.cross_entropy(y_hat, y)
        log = {'trainer_loss':loss}
        # self.logger.experiment.add_scalar('loss',loss)
        return {'loss': loss, 'log': log}

    def training_epoch_end(self, outputs):
        """
            Train Loss at the end of epoch
            Will store logs
        """
        avg_loss = torch.stack([x['trainer_loss'] for x in outputs]).mean()
        logs = {'trainer_loss_epoch': avg_loss}
        self.telegrad_logs['lr'] = self.lr # for telegram bot
        self.telegrad_logs['trainer_loss_epoch'] = avg_loss.item() # for telegram bot
        self.logger.log_metrics({'learning_rate':self.lr}) # if lr is changed by telegram bot
        return {'train_loss': avg_loss, 'log': logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        logs = {'val_loss_epoch': avg_loss}
        self.telegrad_logs['val_loss_epoch'] = avg_loss.item() # for telegram bot
        return {'val_loss': avg_loss, 'log': logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss_epoch': avg_loss}
        return {'avg_test_loss': avg_loss, 'log': logs}

    def prepare_data(self):
        # download only
        self.dataloaders.prepare_data()

    def train_dataloader(self):
        return self.dataloaders.train_dataloader(self.batch_size)

    def val_dataloader(self):
        return self.dataloaders.val_dataloader(self.batch_size)

    def test_dataloader(self):
        return self.dataloaders.test_dataloader(self.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--input_shape', type=int, default=28,
                            help='input image dim for MNIST (must be square image)')
        parser.add_argument('--channels', type=int, nargs='+', default=[1, 32, 32, 16, 8],
                            help='list of channels in each conv layer, including the input channels')
        parser.add_argument('--filters', type=int, nargs='+', default=[3, 3, 3, 3],
                            help='list of filter sizes for the conv layers; length should be one less than the channels list')
        parser.add_argument('--denses', type=int, nargs='+', default=[10],
                            help='list of linear layer outputs after the conv layers')
        parser.add_argument('--activation', type=str, default='relu', choices=['relu', 'sigmoid', 'tanh'],
                            help='activations for nn layers')
        parser.add_argument('--batch_size', type=int, default=32,
                            help='batch size for training')
        # optimizer
        parser.add_argument('--opt', type=str, default='adam', choices=['adam', 'adamax', 'rmsprop'],
                            help='optimizer type for optimization')
        parser.add_argument('--lr', type=float, default=0.001,
                            help='learning rate')
        parser.add_argument('--weight_decay', type=float, default=0,
                            help='weight decay in optimizer')
        return parser
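
A quick sanity check of this ConvNet with its default hparams (a sketch; the constructor again assumes the repository's MNISTDataLoaders is importable):

import torch

model = ConvNet()                # default channels [1, 32, 32, 16, 8], filters [3, 3, 3, 3]
x = torch.randn(2, 1, 28, 28)
out = model(x)                   # conv_net -> flatten -> dense
print(out.shape)                 # torch.Size([2, 10]) with the default denses=[10]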