Example no. 1
0
    def __init__(self, cfg):
        """Build the VAE from a plain-dict configuration.

        Args:
            cfg: configuration mapping (a plain dict, not a Box); must
                provide ``network.lr``, ``network.latent_dim``,
                ``network.input_height`` and ``network.enc_out_dim``.

        Raises:
            ValueError: if ``cfg`` is already a ``Box`` instance.
        """
        super().__init__()
        # Reject pre-wrapped configs so save_hyperparameters records a dict.
        if isinstance(cfg, Box):
            raise ValueError('Pass a dict instead')

        self.save_hyperparameters('cfg')
        self.cfg = Box(cfg)
        self.learning_rate = self.cfg.network.lr

        net_cfg = self.cfg.network

        # Encoder: maps inputs to a feature vector.
        self.encoder = resnet18_encoder(False, False)

        # Decoder: note it does not generate the sample itself -- it
        # produces the parameters of the distribution we sample from,
        # i.e. p_theta(x|z).
        self.decoder = resnet18_decoder(
            latent_dim=net_cfg.latent_dim,
            input_height=net_cfg.input_height,
            first_conv=False,
            maxpool1=False)

        # Heads producing the posterior q(z|x) distribution parameters.
        self.fc_mu = nn.Linear(net_cfg.enc_out_dim, net_cfg.latent_dim)
        self.fc_var = nn.Linear(net_cfg.enc_out_dim, net_cfg.latent_dim)

        # Additional learnable parameter: log std-dev of p_theta(x|z).
        self.log_p_xz_std = nn.Parameter(torch.Tensor([0.0]))
Example no. 2
0
    def __init__(self, beta=4, enc_out_dim=512, latent_dim=256, input_height=128, device='cpu'):
        """beta-VAE built on a ResNet-18 encoder/decoder pair.

        Args:
            beta: weight of the KL term in the loss.
            enc_out_dim: encoder output feature size (512 for resnet18).
            latent_dim: dimensionality of the latent space.
            input_height: height of the (square) input images.
            device: device identifier stored for later use.
        """
        super().__init__()

        self.beta = beta
        self.latent_dim = latent_dim
        self.device = device
        self.input_height = input_height

        # Encoder backbone.
        self.encoder = resnet18_encoder(False, False)
        # Decoder followed by Tanh so reconstructions are clamped to the
        # input range [-1, 1].
        self.decoder = nn.Sequential(
            resnet18_decoder(
                latent_dim=latent_dim,
                input_height=input_height,
                first_conv=False,
                maxpool1=False),
            nn.Tanh(),
        )

        # Heads producing the posterior q(z|x) distribution parameters.
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # Learnable log-scale of the Gaussian likelihood p(x|z).
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

        # NOTE(review): purpose of p is not visible in this block
        # (presumably a dropout/keep probability) -- confirm at usage site.
        self.p = 0.2
    def __init__(self,
                 input_height,
                 enc_type='resnet18',
                 first_conv=False,
                 maxpool1=False,
                 enc_out_dim=512,
                 kl_coeff=0.1,
                 latent_dim=256,
                 lr=1e-4,
                 **kwargs):
        """Construct the VAE with a selectable ResNet backbone.

        Args:
            input_height: height of the images
            enc_type: option between resnet18 or resnet50
            first_conv: use standard kernel_size 7, stride 2 at start or
                replace it with kernel_size 3, stride 1 conv
            maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2
            enc_out_dim: set according to the out_channel count of
                encoder used (512 for resnet18, 2048 for resnet50)
            kl_coeff: coefficient for kl term of the loss
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """
        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        # Encoder-type name -> (encoder ctor, decoder ctor). Unknown names
        # fall back to the resnet18 pair (same behavior as before).
        backbones = {
            'resnet18': (resnet18_encoder, resnet18_decoder),
            'resnet50': (resnet50_encoder, resnet50_decoder),
        }
        make_enc, make_dec = backbones.get(enc_type, backbones['resnet18'])
        self.encoder = make_enc(first_conv, maxpool1)
        self.decoder = make_dec(self.latent_dim, self.input_height,
                                first_conv, maxpool1)

        # Heads producing the posterior q(z|x) distribution parameters.
        self.fc_mu = nn.Linear(self.enc_out_dim, self.latent_dim)
        self.fc_var = nn.Linear(self.enc_out_dim, self.latent_dim)
Example no. 4
0
    def __init__(
        self,
        input_height: int,
        enc_type: str = "resnet18",
        first_conv: bool = False,
        maxpool1: bool = False,
        enc_out_dim: int = 512,
        latent_dim: int = 256,
        lr: float = 1e-4,
        **kwargs,
    ):
        """Construct the model with a selectable ResNet backbone.

        Args:
            input_height: height of the images
            enc_type: option between resnet18 or resnet50
            first_conv: use standard kernel_size 7, stride 2 at start or
                replace it with kernel_size 3, stride 1 conv
            maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2
            enc_out_dim: set according to the out_channel count of
                encoder used (512 for resnet18, 2048 for resnet50)
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """
        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        # Encoder-type name -> (encoder ctor, decoder ctor). Unknown names
        # fall back to the resnet18 pair (same behavior as before).
        backbones = {
            "resnet18": (resnet18_encoder, resnet18_decoder),
            "resnet50": (resnet50_encoder, resnet50_decoder),
        }
        make_enc, make_dec = backbones.get(enc_type, backbones["resnet18"])
        self.encoder = make_enc(first_conv, maxpool1)
        self.decoder = make_dec(self.latent_dim, self.input_height,
                                first_conv, maxpool1)

        # Bottleneck projection from encoder features to the latent code.
        self.fc = nn.Linear(self.enc_out_dim, self.latent_dim)
Example no. 5
0
    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32):
        """VAE built on a fixed ResNet-18 encoder/decoder pair.

        Args:
            enc_out_dim: encoder output feature size (512 for resnet18).
            latent_dim: dimensionality of the latent space.
            input_height: height of the (square) input images.
        """
        super().__init__()

        self.save_hyperparameters()

        # Encoder/decoder backbone (positional args: latent_dim,
        # input_height, first_conv, maxpool1).
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(latent_dim, input_height,
                                        False, False)

        # Heads producing the posterior q(z|x) distribution parameters.
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # Learnable log-scale of the Gaussian likelihood p(x|z).
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))
Example no. 6
0
    def __init__(self,
                 enc_out_dim=512,
                 latent_dim=256,
                 load_path=None,
                 device='cpu'):
        """Encoder-only VAE for use in RL algorithms.

        Identical to the VAE module in RL_VAE/vae.py, but without the
        decoder part.

        Args:
            enc_out_dim: encoder output feature size.
            latent_dim: dimensionality of the latent space.
            load_path: optional checkpoint path; weights are loaded when set.
            device: device identifier stored for later use.
        """
        super().__init__()
        self.device = device
        self.latent_dim = latent_dim

        # Encoder backbone plus posterior q(z|x) parameter heads.
        self.encoder = resnet18_encoder(False, False)
        self.fc_mu = nn.Linear(enc_out_dim, latent_dim)
        self.fc_var = nn.Linear(enc_out_dim, latent_dim)

        # Restore pretrained weights when a checkpoint path is supplied.
        if load_path is not None:
            self.load_weights(load_path)
    def __init__(self,
                 input_height,
                 enc_type='resnet18',
                 first_conv=False,
                 maxpool1=False,
                 enc_out_dim=512,
                 kl_coeff=0.1,
                 latent_dim=256,
                 lr=1e-4,
                 **kwargs):
        """
        Standard AE

        Model is available pretrained on different datasets:

        Example::

            # not pretrained
            ae = AE()

            # pretrained on imagenet
            ae = AE.from_pretrained('resnet50-imagenet')

            # pretrained on cifar10
            ae = AE.from_pretrained('resnet18-cifar10')

        Args:

            input_height: height of the images
            enc_type: option between resnet18 or resnet50
            first_conv: use standard kernel_size 7, stride 2 at start or
                replace it with kernel_size 3, stride 1 conv
            maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2
            enc_out_dim: set according to the out_channel count of
                encoder used (512 for resnet18, 2048 for resnet50)
            latent_dim: dim of latent space
            lr: learning rate for Adam
        """
        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        # Encoder-type name -> (encoder ctor, decoder ctor). Unknown names
        # fall back to the resnet18 pair (same behavior as before).
        backbones = {
            'resnet18': (resnet18_encoder, resnet18_decoder),
            'resnet50': (resnet50_encoder, resnet50_decoder),
        }
        make_enc, make_dec = backbones.get(enc_type, backbones['resnet18'])
        self.encoder = make_enc(first_conv, maxpool1)
        self.decoder = make_dec(self.latent_dim, self.input_height,
                                first_conv, maxpool1)

        # Bottleneck projection from encoder features to the latent code.
        self.fc = nn.Linear(self.enc_out_dim, self.latent_dim)
Example no. 8
0
    def __init__(self, layer_idx, z_dim, im_size):
        """VAE whose encoder/decoder architecture depends on the image size.

        We require different network architectures for different image
        sizes of the datasets (MNIST, UCSD, MVTEC).

        Args:
            layer_idx: index of the conv layer whose activations and
                gradients are captured via forward/backward hooks.
            z_dim: dimensionality of the latent space.
            im_size: (channels, height, width) of the input images; must
                be one of (1, 28, 28), (1, 100, 100), (3, 256, 256) or
                (1, 256, 256).

        Raises:
            ValueError: if ``im_size`` matches none of the supported layouts.
        """
        super(VAE, self).__init__()

        self.z_dim, self.layer_idx = z_dim, layer_idx

        # Encoder/decoder for MNIST
        if im_size[1:] == (28, 28):
            # This part defines the encoder part of the VAE on images of (1*28*28)
            self.enc_main = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), # 28 - 14
                nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 14 - 7
                nn.ReLU(),
                nn.Flatten(), # 7*7*128 = 6272
                nn.Linear(6272, 1024),
                nn.ReLU(),
            )
            # This part defines the decoder part of the VAE on images of (1*28*28)
            self.dec_vae = nn.Sequential(
                nn.Linear(self.z_dim, 1024),
                nn.ReLU(),
                nn.Linear(1024, 6272),
                nn.ReLU(),
                Reshape((128, 7, 7)),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
                nn.Sigmoid()
            )

            # ResNet encoder is saved in self.encoder, but only used for the MVTEC
            # dataset. Setting this to None allows us to check which encoder/decoder
            # to use, as they differ slightly in input and output.
            self.encoder = None

            encoder_output_dim = 1024

        # Encoder/decoder for UCSD
        elif im_size[1:] == (100, 100):
            # This part defines the encoder part of the VAE on images of (1*100*100)
            self.enc_main = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1), # (100 * 100) - (50 * 50)
                nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # (50 * 50) - (25 * 25)
                nn.ReLU(),
                nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # (25 * 25) - (12 * 12)
                nn.ReLU(),
                nn.Flatten(), # 12*12*256 = 36864
                nn.Linear(36864, 1024),
                nn.ReLU(),
            )
            # This part defines the decoder part of the VAE on images of (1*100*100)
            self.dec_vae = nn.Sequential(
                nn.Linear(self.z_dim, 1024),
                nn.ReLU(),
                nn.Linear(1024, 36864),
                nn.ReLU(),
                Reshape((256, 12, 12)),
                nn.ReLU(),
                nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
                nn.Sigmoid()
            )

            # ResNet encoder is saved in self.encoder, but only used for the MVTEC
            # dataset. Setting this to None allows us to check which encoder/decoder
            # to use, as they differ slightly in input and output.
            self.encoder = None

            encoder_output_dim = 1024

        # Encoder/decoder for MVTEC color images
        elif im_size == (3, 256, 256):
            self.encoder = resnet18_encoder(first_conv=True, maxpool1=True)
            self.decoder = resnet18_decoder(self.z_dim, im_size[1], first_conv=True, maxpool1=True)

            encoder_output_dim = 512

        # Encoder/decoder for MVTEC grayscale images
        elif im_size == (1, 256, 256):
            self.encoder = resnet18_encoder(first_conv=True, maxpool1=True)
            self.decoder = resnet18_decoder(self.z_dim, im_size[1], first_conv=True, maxpool1=True)

            # Change encoder first conv and decoder last conv to deal with grayscale instead of color
            self.encoder.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
            self.decoder.conv1 = nn.Conv2d(64, 1, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)

            encoder_output_dim = 512

        else:
            # Previously an unsupported size fell through silently and crashed
            # below with a confusing NameError on encoder_output_dim; fail fast
            # with a clear message instead.
            raise ValueError(
                f'Unsupported im_size {im_size!r}; expected (1, 28, 28), '
                f'(1, 100, 100), (3, 256, 256) or (1, 256, 256)')

        # Final part of the encoder network, computing mu and log_var
        self.mu = nn.Linear(encoder_output_dim, self.z_dim)
        self.var = nn.Linear(encoder_output_dim, self.z_dim)

        # Norm mu and log_var are only used when we enable "normal_diff" inference
        self.norm_mu = None
        self.norm_log_var = None

        # Forward hook: saves a layer's output activation and registers a
        # backward hook on it to capture the corresponding gradient. Note
        # `self` here is the hooked module (hook signature: module, input,
        # output), not the VAE instance.
        def save_layer_activations(self, input, output):
            input[0].requires_grad_()

            # Save the activation
            self.layer_activ_out = output
            # Create function which will save the gradient during backward pass
            def save_layer_grads(g_output):
                self.layer_grad_out = g_output
            # Register hook to the activation
            self.layer_activ_out.requires_grad_(True)
            self.layer_activ_out.register_hook(save_layer_grads)

        # Register these hooks to the correct module in the network (given by layer_idx)
        if self.encoder is not None:
            # ResNet path: count conv modules until we reach layer_idx.
            conv_counter = 0
            for name, module in self.encoder.named_modules():
                if 'conv' in name:
                    if conv_counter == self.layer_idx:
                        module.register_forward_hook(save_layer_activations)
                    conv_counter += 1
        else:
            # Sequential path: index directly into enc_main.
            self.enc_main[layer_idx].register_forward_hook(save_layer_activations)
Example no. 9
0
 def __init__(self, opt):
     super().__init__(opt)
     self.encoder = resnet18_encoder(False, False)