Example #1
class Decoder(nn.Module):
    def __init__(
            self,
            input_dim=2,
            h_dim=64,
            z_dim=2,
            nonlinearity='tanh',
            num_hidden_layers=1,
            init='gaussian',  # set to None to skip the Gaussian weight init
    ):
        super().__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.nonlinearity = nonlinearity
        self.num_hidden_layers = num_hidden_layers
        self.init = init

        self.main = MLP(input_dim=z_dim,
                        hidden_dim=h_dim,
                        output_dim=h_dim,
                        nonlinearity=nonlinearity,
                        num_hidden_layers=num_hidden_layers - 1,
                        use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim, input_dim)

        if self.init == 'gaussian':
            self.reset_parameters()
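
All of these modules delegate mean/log-variance prediction and sampling to NormalDistributionLinear, which the examples never define. The following is a minimal sketch inferred from its call sites: it is called to produce (mu, logvar), exposes a mean_fn linear layer (see reset_parameters in Example #11), accepts an optional nonlinearity for the log-variance, and provides sample_gaussian. The real implementation may differ.

import torch
import torch.nn as nn

class NormalDistributionLinear(nn.Module):
    # Sketch only: predicts the mean and log-variance of a diagonal
    # Gaussian and samples from it with the reparameterization trick.
    def __init__(self, input_dim, output_dim, nonlinearity=None):
        super().__init__()
        self.mean_fn = nn.Linear(input_dim, output_dim)    # mu head
        self.logvar_fn = nn.Linear(input_dim, output_dim)  # log sigma^2 head
        self.nonlinearity = nonlinearity  # optional log-variance clipping

    def forward(self, h):
        mu = self.mean_fn(h)
        logvar = self.logvar_fn(h)
        if self.nonlinearity is not None:
            logvar = self.nonlinearity(logvar)  # assumed to be a callable
        return mu, logvar

    def sample_gaussian(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I), so gradients flow
        # through mu and logvar but not through the random draw.
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)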
Example #2
class AuxDecoder(nn.Module):
    def __init__(
            self,
            z_dim=32,
            c_dim=450,
            z0_dim=100,
            act=nn.ELU(),
            do_center=False,
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.z0_dim = z0_dim
        self.do_center = do_center

        # A convolutional context encoder (identical to self.enc in
        # Example #12) is commented out in the source; ctx is computed upstream.
        self.fc = nn.Sequential(
            nn.Linear(c_dim + z_dim, c_dim, bias=True),
            act,
        )
        self.reparam = NormalDistributionLinear(c_dim, z0_dim)
Example #3
class Encoder(nn.Module):
    def __init__(
        self,
        input_height=28,
        input_channels=1,
        z0_dim=100,
        z_dim=32,
        nonlinearity='softplus',
    ):
        super().__init__()
        self.input_height = input_height
        self.input_channels = input_channels
        self.z0_dim = z0_dim
        self.z_dim = z_dim
        self.nonlinearity = nonlinearity

        s_h = input_height
        s_h2 = conv_out_size(s_h, 5, 2, 2)
        s_h4 = conv_out_size(s_h2, 5, 2, 2)
        s_h8 = conv_out_size(s_h4, 5, 2, 2)
        # e.g. 28 -> 14 -> 7 -> 4 for the default 28x28 input

        self.afun = get_nonlinear_func(nonlinearity)
        self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True)
        self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True)
        self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True)
        self.fc = nn.Linear(s_h8 * s_h8 * 32 + z0_dim, 800, bias=True)
        self.reparam = NormalDistributionLinear(800, z_dim)
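
conv_out_size is the usual output-size arithmetic for a strided convolution. A sketch, with the concrete sizes it produces for the 28x28 default:

def conv_out_size(size, kernel, stride, padding):
    # floor((size + 2*padding - kernel) / stride) + 1, the standard
    # Conv2d output-size formula (dilation = 1)
    return (size + 2 * padding - kernel) // stride + 1

# With kernel=5, stride=2, padding=2: 28 -> 14 -> 7 -> 4, so the
# flattened conv features are 4 * 4 * 32 = 512 entries per image.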
Example #4
    def __init__(
        self,
        input_dim=2,
        h_dim=8,
        noise_dim=2,
        nonlinearity='softplus',
        num_hidden_layers=1,
        clip_logvar=None,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.noise_dim = noise_dim
        self.nonlinearity = nonlinearity
        self.num_hidden_layers = num_hidden_layers
        self.clip_logvar = clip_logvar

        self.main = MLP(input_dim=input_dim,
                        hidden_dim=h_dim,
                        output_dim=h_dim,
                        nonlinearity=nonlinearity,
                        num_hidden_layers=num_hidden_layers - 1,
                        use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim,
                                                noise_dim,
                                                nonlinearity=clip_logvar)
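
The fully connected examples are built on an MLP helper that is also not shown. Below is a minimal sketch consistent with the keyword arguments used here; the exact layer count and activation handling in the original may differ. Note that callers pass num_hidden_layers - 1, with the output layer supplying the final transformation and use_nonlinearity_output controlling whether it is followed by the activation.

import torch.nn as nn

_ACTS = {'tanh': nn.Tanh, 'softplus': nn.Softplus,
         'relu': nn.ReLU, 'elu': nn.ELU}  # assumed name -> module table

class MLP(nn.Module):
    # Sketch only: a stack of Linear + activation blocks.
    def __init__(self, input_dim, hidden_dim, output_dim,
                 nonlinearity='tanh', num_hidden_layers=1,
                 use_nonlinearity_output=False):
        super().__init__()
        act = _ACTS[nonlinearity]
        layers = [nn.Linear(input_dim, hidden_dim), act()]
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(hidden_dim, hidden_dim), act()]
        layers.append(nn.Linear(hidden_dim, output_dim))
        if use_nonlinearity_output:
            layers.append(act())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)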
Example #5
    def __init__(
            self,
            z0_dim=32,
            c_dim=450,
            act=nn.ELU(),
            do_center=False,
            clip_logvar=None,
    ):
        super().__init__()
        self.z0_dim = z0_dim
        self.c_dim = c_dim
        self.do_center = do_center
        self.clip_logvar = clip_logvar

        # A convolutional context encoder (identical to self.enc in
        # Example #12) is commented out in the source; ctx is computed upstream.
        self.reparam = NormalDistributionLinear(c_dim,
                                                z0_dim,
                                                nonlinearity=clip_logvar)
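
clip_logvar is handed to NormalDistributionLinear as the nonlinearity applied to the predicted log-variance. Its form is not shown in these examples; a common choice, given purely as a hedged illustration, is a scaled tanh that bounds the log-variance:

import torch

def clip_logvar(logvar, bound=4.0):
    # Hypothetical clipping nonlinearity (not from the original source):
    # smoothly restricts log sigma^2 to (-bound, bound) for stability.
    return bound * torch.tanh(logvar / bound)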
Example #6
class SimpleEncoder(Encoder):
    def __init__(
        self,
        input_dim=2,
        noise_dim=2,
        h_dim=64,
        z_dim=2,
        nonlinearity='tanh',
        num_hidden_layers=1,
        enc_input=False,
        enc_noise=False,
        clip_logvar=None,
    ):
        super().__init__(
            input_dim=input_dim,
            noise_dim=noise_dim,
            h_dim=h_dim,
            z_dim=z_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers,
            enc_input=enc_input,
            enc_noise=enc_noise,
            clip_logvar=clip_logvar,
        )
        inp_dim = input_dim if not enc_input else h_dim
        ctx_dim = noise_dim if not enc_noise else h_dim

        self.inp_encode = Identity() if not enc_input else MLP(
            input_dim=input_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.nos_encode = Identity() if not enc_noise else MLP(
            input_dim=noise_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.fc = MLP(input_dim=inp_dim + ctx_dim,
                      hidden_dim=h_dim,
                      output_dim=h_dim,
                      nonlinearity=nonlinearity,
                      num_hidden_layers=num_hidden_layers - 1,
                      use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim,
                                                z_dim,
                                                nonlinearity=clip_logvar)
Example #7
class Encoder(nn.Module):
    def __init__(
        self,
        input_dim=784,
        h_dim=300,
        z_dim=32,
        nonlinearity='softplus',
        num_hidden_layers=2,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.nonlinearity = nonlinearity
        self.num_hidden_layers = num_hidden_layers

        self.main = MLP(input_dim=input_dim,
                        hidden_dim=h_dim,
                        output_dim=h_dim,
                        nonlinearity=nonlinearity,
                        num_hidden_layers=num_hidden_layers - 1,
                        use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim, z_dim)
Example #8
class Encoder(nn.Module):
    def __init__(
        self,
        input_height=28,
        input_channels=1,
        z_dim=32,
        nonlinearity='softplus',
    ):
        super().__init__()
        self.input_height = input_height
        self.input_channels = input_channels
        self.z_dim = z_dim
        self.nonlinearity = nonlinearity

        s_h = input_height
        s_h2 = conv_out_size(s_h, 5, 2, 2)
        s_h4 = conv_out_size(s_h2, 5, 2, 2)
        s_h8 = conv_out_size(s_h4, 5, 2, 2)
        # e.g. 28 -> 14 -> 7 -> 4 for the default 28x28 input

        self.afun = get_nonlinear_func(nonlinearity)
        self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True)
        self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True)
        self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True)
        self.fc = nn.Linear(s_h8 * s_h8 * 32, 800, bias=True)
        self.reparam = NormalDistributionLinear(800, z_dim)

    def sample(self, mu, logvar):
        return self.reparam.sample_gaussian(mu, logvar)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, self.input_channels, self.input_height,
                   self.input_height)

        # rescale
        x = 2 * x - 1

        # forward
        h1 = self.afun(self.conv1(x))
        h2 = self.afun(self.conv2(h1))
        h3 = self.afun(self.conv3(h2))
        h3 = h3.view(batch_size, -1)
        h4 = self.afun(self.fc(h3))
        mu, logvar = self.reparam(h4)

        # sample
        z = self.sample(mu, logvar)

        return z, mu, logvar
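
get_nonlinear_func maps an activation name to a callable; a plausible sketch follows, together with a quick smoke test of this encoder. Both assume only PyTorch plus the conv_out_size and NormalDistributionLinear sketches above; the real helper may support more names.

import torch
import torch.nn.functional as F

def get_nonlinear_func(name):
    # assumed lookup table; extend as needed
    return {'tanh': torch.tanh, 'softplus': F.softplus,
            'relu': F.relu, 'elu': F.elu}[name]

# Smoke test: 8 MNIST-sized images in, 32-dimensional codes out.
enc = Encoder(input_height=28, input_channels=1, z_dim=32)
x = torch.rand(8, 1, 28, 28)
z, mu, logvar = enc(x)
print(z.shape, mu.shape, logvar.shape)  # 3 x torch.Size([8, 32])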
Example #9
class SimpleEncoder(Encoder):
    def __init__(
        self,
        input_dim=2,
        noise_dim=2,
        h_dim=64,
        z_dim=2,
        nonlinearity='tanh',
        num_hidden_layers=1,
        enc_input=False,
        enc_noise=False,
        clip_logvar=None,
    ):
        super().__init__(
            input_dim=input_dim,
            noise_dim=noise_dim,
            h_dim=h_dim,
            z_dim=z_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers,
            enc_input=enc_input,
            enc_noise=enc_noise,
            clip_logvar=clip_logvar,
        )
        inp_dim = input_dim if not enc_input else h_dim
        ctx_dim = noise_dim if not enc_noise else h_dim

        self.inp_encode = Identity() if not enc_input else MLP(
            input_dim=input_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.nos_encode = Identity() if not enc_noise else MLP(
            input_dim=noise_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.fc = MLP(input_dim=inp_dim + ctx_dim,
                      hidden_dim=h_dim,
                      output_dim=h_dim,
                      nonlinearity=nonlinearity,
                      num_hidden_layers=num_hidden_layers - 1,
                      use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim,
                                                z_dim,
                                                nonlinearity=clip_logvar)

    def sample(self, mu_z, logvar_z):
        return self.reparam.sample_gaussian(mu_z, logvar_z)

    def _forward_all(self, inp, nos):
        h1 = torch.cat([inp, nos], dim=1)
        h2 = self.fc(h1)
        mu_z, logvar_z = self.reparam(h2)
        z = self.sample(mu_z, logvar_z)
        return z, mu_z, logvar_z, h2
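
The listing stops at _forward_all; the full forward is not shown. By analogy with AuxDecoder.forward in Example #15, it plausibly encodes input and noise separately and tiles the input across the nz noise samples before fusing. The following is an assumption, not source code:

    def forward(self, x, noise, nz=1):
        # Hypothetical, modeled on AuxDecoder.forward in Example #15.
        batch_size = x.size(0)
        inp = self.inp_encode(x.view(batch_size, -1))
        nos = self.nos_encode(noise)
        assert nos.size(0) == batch_size * nz
        inp = inp.unsqueeze(1).expand(-1, nz, -1).contiguous()
        inp = inp.view(batch_size * nz, -1)
        z, mu_z, logvar_z, _ = self._forward_all(inp, nos)
        return z, mu_z, logvar_z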
Example #10
class AuxDecoder(nn.Module):
    def __init__(
        self,
        input_dim=2,
        z_dim=2,
        noise_dim=2,
        h_dim=64,
        nonlinearity='tanh',
        num_hidden_layers=1,
        enc_input=False,
        enc_latent=False,
        clip_logvar=None,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.noise_dim = noise_dim
        self.h_dim = h_dim
        self.nonlinearity = nonlinearity
        self.num_hidden_layers = num_hidden_layers
        self.enc_input = enc_input
        self.enc_latent = enc_latent
        inp_dim = input_dim if not enc_input else h_dim
        ltt_dim = z_dim if not enc_latent else h_dim

        self.inp_encode = Identity() if not enc_input else MLP(
            input_dim=input_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.ltt_encode = Identity() if not enc_latent else MLP(
            input_dim=z_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.fc = MLP(input_dim=inp_dim + ltt_dim,
                      hidden_dim=h_dim,
                      output_dim=h_dim,
                      nonlinearity=nonlinearity,
                      num_hidden_layers=num_hidden_layers - 1,
                      use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim,
                                                noise_dim,
                                                nonlinearity=clip_logvar)
Example #11
class Decoder(nn.Module):
    def __init__(
            self,
            input_dim=2,
            h_dim=64,
            z_dim=2,
            nonlinearity='tanh',
            num_hidden_layers=1,
            init='gaussian',  # set to None to skip the Gaussian weight init
    ):
        super().__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.nonlinearity = nonlinearity
        self.num_hidden_layers = num_hidden_layers
        self.init = init

        self.main = MLP(input_dim=z_dim,
                        hidden_dim=h_dim,
                        output_dim=h_dim,
                        nonlinearity=nonlinearity,
                        num_hidden_layers=num_hidden_layers - 1,
                        use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim, input_dim)

        if self.init == 'gaussian':
            self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.reparam.mean_fn.weight)

    def sample(self, mu, logvar):
        return self.reparam.sample_gaussian(mu, logvar)

    def forward(self, z):
        batch_size = z.size(0)
        z = z.view(batch_size, -1)

        # forward
        h = self.main(z)
        mu, logvar = self.reparam(h)

        # sample
        x = self.sample(mu, logvar)

        return x, mu, logvar
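
A short usage sketch for this decoder, assuming the MLP and NormalDistributionLinear sketches above:

import torch

dec = Decoder(input_dim=2, h_dim=64, z_dim=2)
z = torch.randn(16, 2)             # latent samples
x, mu, logvar = dec(z)             # stochastic reconstructions
print(x.shape)                     # torch.Size([16, 2])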
Example #12
class Encoder(nn.Module):
    def __init__(
            self,
            z_dim=32,
            c_dim=450,
            act=nn.ELU(),
            do_center=False,
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.do_center = do_center

        self.enc = nn.Sequential(
            nn_.ResConv2d(1, 16, 3, 2, padding=1, activation=act), act,
            nn_.ResConv2d(16, 16, 3, 1, padding=1, activation=act), act,
            nn_.ResConv2d(16, 32, 3, 2, padding=1, activation=act), act,
            nn_.ResConv2d(32, 32, 3, 1, padding=1, activation=act), act,
            nn_.ResConv2d(32, 32, 3, 2, padding=1, activation=act), act,
            nn_.Reshape((-1, 32 * 4 * 4)), nn_.ResLinear(32 * 4 * 4, c_dim),
            act)
        self.reparam = NormalDistributionLinear(c_dim, z_dim)

    def sample(self, mu, logvar):
        return self.reparam.sample_gaussian(mu, logvar)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, 1, 28, 28)

        # rescale
        if self.do_center:
            x = 2 * x - 1

        # enc
        ctx = self.enc(x)
        mu, logvar = self.reparam(ctx)

        # sample
        z = self.sample(mu, logvar)

        return z, mu, logvar
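
The nn_ helpers here are project-specific: ResConv2d and ResLinear are presumably residual variants of Conv2d and Linear, and Reshape flattens activations inside nn.Sequential. A sketch of Reshape consistent with its use above (an assumption):

import torch.nn as nn

class Reshape(nn.Module):
    # Sketch only: lets a view() live inside nn.Sequential.
    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)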
Example #13
class Encoder(nn.Module):
    def __init__(
        self,
        input_dim=784,
        h_dim=300,
        z_dim=32,
        nonlinearity='softplus',
        num_hidden_layers=2,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.nonlinearity = nonlinearity
        self.num_hidden_layers = num_hidden_layers

        self.main = MLP(input_dim=input_dim,
                        hidden_dim=h_dim,
                        output_dim=h_dim,
                        nonlinearity=nonlinearity,
                        num_hidden_layers=num_hidden_layers - 1,
                        use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim, z_dim)

    def sample(self, mu, logvar):
        return self.reparam.sample_gaussian(mu, logvar)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, self.input_dim)

        # rescale
        x = 2 * x - 1

        # forward
        h = self.main(x)
        mu, logvar = self.reparam(h)

        # sample
        z = self.sample(mu, logvar)

        return z, mu, logvar
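
Usage sketch, assuming the MLP and NormalDistributionLinear sketches above:

import torch

enc = Encoder(input_dim=784, h_dim=300, z_dim=32)
x = torch.rand(8, 784)             # flattened images in [0, 1]
z, mu, logvar = enc(x)
print(z.shape)                     # torch.Size([8, 32])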
Example #14
class Encoder(nn.Module):
    def __init__(
        self,
        input_height=28,
        input_channels=1,
        z0_dim=100,
        z_dim=32,
        nonlinearity='softplus',
    ):
        super().__init__()
        self.input_height = input_height
        self.input_channels = input_channels
        self.z0_dim = z0_dim
        self.z_dim = z_dim
        self.nonlinearity = nonlinearity

        s_h = input_height
        s_h2 = conv_out_size(s_h, 5, 2, 2)
        s_h4 = conv_out_size(s_h2, 5, 2, 2)
        s_h8 = conv_out_size(s_h4, 5, 2, 2)
        # e.g. 28 -> 14 -> 7 -> 4 for the default 28x28 input

        self.afun = get_nonlinear_func(nonlinearity)
        self.conv1 = nn.Conv2d(self.input_channels, 16, 5, 2, 2, bias=True)
        self.conv2 = nn.Conv2d(16, 32, 5, 2, 2, bias=True)
        self.conv3 = nn.Conv2d(32, 32, 5, 2, 2, bias=True)
        self.fc = nn.Linear(s_h8 * s_h8 * 32 + z0_dim, 800, bias=True)
        self.reparam = NormalDistributionLinear(800, z_dim)

    def sample(self, mu_z, logvar_z):
        return self.reparam.sample_gaussian(mu_z, logvar_z)

    def forward(self, x, z0, nz=1):
        batch_size = x.size(0)
        x = x.view(batch_size, self.input_channels, self.input_height,
                   self.input_height)
        assert z0.size(0) == batch_size * nz

        # rescale
        x = 2 * x - 1

        # forward
        h1 = self.afun(self.conv1(x))
        h2 = self.afun(self.conv2(h1))
        h3 = self.afun(self.conv3(h2))
        h3 = h3.view(batch_size, -1)

        # view
        h3 = h3.unsqueeze(1).expand(-1, nz, -1).contiguous()
        h3 = h3.view(batch_size * nz, -1)

        # concat
        h3z0 = torch.cat([h3, z0], dim=1)

        # forward
        h4 = self.afun(self.fc(h3z0))
        mu, logvar = self.reparam(h4)

        # sample
        z = self.sample(mu, logvar)

        return z, mu, logvar, h4
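
The unsqueeze/expand/view sequence tiles each image's conv features across the nz latent samples, so that row i * nz + j of the tiled tensor pairs input i with its j-th z0 draw (assuming z0 rows are grouped per input the same way). In isolation:

import torch

B, D, nz = 2, 3, 4
h = torch.arange(B * D, dtype=torch.float32).view(B, D)
tiled = h.unsqueeze(1).expand(-1, nz, -1).contiguous().view(B * nz, D)
print(tiled.shape)                           # torch.Size([8, 3])
print(torch.equal(tiled[0], tiled[nz - 1]))  # True: rows 0..nz-1 share input 0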
Example #15
class AuxDecoder(nn.Module):
    def __init__(
        self,
        input_dim=2,
        z_dim=2,
        noise_dim=2,
        h_dim=64,
        nonlinearity='tanh',
        num_hidden_layers=1,
        enc_input=False,
        enc_latent=False,
        clip_logvar=None,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.noise_dim = noise_dim
        self.h_dim = h_dim
        self.nonlinearity = nonlinearity
        self.num_hidden_layers = num_hidden_layers
        self.enc_input = enc_input
        self.enc_latent = enc_latent
        inp_dim = input_dim if not enc_input else h_dim
        ltt_dim = z_dim if not enc_latent else h_dim

        self.inp_encode = Identity() if not enc_input else MLP(
            input_dim=input_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.ltt_encode = Identity() if not enc_latent else MLP(
            input_dim=z_dim,
            hidden_dim=h_dim,
            output_dim=h_dim,
            nonlinearity=nonlinearity,
            num_hidden_layers=num_hidden_layers - 1,
            use_nonlinearity_output=True)
        self.fc = MLP(input_dim=inp_dim + ltt_dim,
                      hidden_dim=h_dim,
                      output_dim=h_dim,
                      nonlinearity=nonlinearity,
                      num_hidden_layers=num_hidden_layers - 1,
                      use_nonlinearity_output=True)
        self.reparam = NormalDistributionLinear(h_dim,
                                                noise_dim,
                                                nonlinearity=clip_logvar)

    def sample(self, mu, logvar):
        return self.reparam.sample_gaussian(mu, logvar)

    def _forward_inp(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, self.input_dim)

        # enc
        inp = self.inp_encode(x)

        return inp

    def _forward_ltt(self, z):
        # enc
        ltt = self.ltt_encode(z)

        return ltt

    def _forward_all(self, inp, ltt):
        h1 = torch.cat([inp, ltt], dim=1)
        h2 = self.fc(h1)
        mu_n, logvar_n = self.reparam(h2)
        noise = self.sample(mu_n, logvar_n)
        return noise, mu_n, logvar_n

    def forward(self, x, z, nz=1):
        batch_size = x.size(0)

        # enc
        ltt = self._forward_ltt(z)
        inp = self._forward_inp(x)

        # view
        assert ltt.size(0) == batch_size * nz
        inp = inp.unsqueeze(1).expand(-1, nz, -1).contiguous()
        inp = inp.view(batch_size * nz, -1)

        # forward
        noise, mu_n, logvar_n = self._forward_all(inp, ltt)

        return noise, mu_n, logvar_n
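
Usage sketch with three latent samples per input, assuming the MLP and NormalDistributionLinear sketches above:

import torch

aux = AuxDecoder(input_dim=2, z_dim=2, noise_dim=2, h_dim=64)
x = torch.randn(8, 2)
z = torch.randn(8 * 3, 2)          # nz = 3 latent samples per input
noise, mu_n, logvar_n = aux(x, z, nz=3)
print(noise.shape)                 # torch.Size([24, 2])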
Example #16
class AuxDecoder(nn.Module):
    def __init__(
            self,
            z_dim=32,
            c_dim=450,
            z0_dim=100,
            act=nn.ELU(),
            do_center=False,
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.z0_dim = z0_dim
        self.do_center = do_center

        # A convolutional context encoder (identical to self.enc in
        # Example #12) is commented out in the source; ctx is computed upstream.
        self.fc = nn.Sequential(
            nn.Linear(c_dim + z_dim, c_dim, bias=True),
            act,
        )
        self.reparam = NormalDistributionLinear(c_dim, z0_dim)

    def sample(self, mu, logvar):
        return self.reparam.sample_gaussian(mu, logvar)

    # An earlier revision of forward took raw images and ran the
    # (now removed) self.enc; forward now takes the precomputed context.
    def forward(self, ctx, z, nz=1):
        batch_size = ctx.size(0)

        # view
        assert z.size(0) == batch_size * nz
        ctx = ctx.unsqueeze(1).expand(-1, nz, -1).contiguous()
        ctx = ctx.view(batch_size * nz, -1)

        # concat
        ctxz = torch.cat([ctx, z], dim=1)

        # forward
        h = self.fc(ctxz)
        mu, logvar = self.reparam(h)

        # sample
        z0 = self.sample(mu, logvar)

        return z0, mu, logvar
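
Read together with Example #12, the intended data flow is: the convolutional encoder produces a context ctx (and a latent z), and this module maps each (ctx, z) pair to the auxiliary latent z0. A shape-level sketch, assuming the NormalDistributionLinear sketch above and faking ctx with random features:

import torch

ctx = torch.randn(8, 450)          # stand-in for Encoder features (batch, c_dim)
z = torch.randn(8 * 2, 32)         # nz = 2 latent samples per input
aux = AuxDecoder(z_dim=32, c_dim=450, z0_dim=100)
z0, mu, logvar = aux(ctx, z, nz=2)
print(z0.shape)                    # torch.Size([16, 100])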