Exemplo n.º 1
0
    def __init__(self,
                 nn='v1',
                 name='vae2',
                 z_dim=10,
                 x_dim=24,
                 c_dim=0,
                 warmup=False,
                 var_pen=1,
                 y_dim=0):
        """Construct the encoder/decoder pair and a fixed latent prior.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                whose ``Encoder``/``Decoder`` classes are instantiated
                (this is not ``torch.nn``).
            name: Model name kept on the instance.
            z_dim: Latent size passed to encoder and decoder.
            x_dim: Input size passed to encoder and decoder.
            c_dim: Extra size passed to encoder and decoder; its meaning is
                defined by the chosen architecture.
            warmup: Stored on the instance; not used in this constructor.
            var_pen: Stored on the instance; not used in this constructor.
            y_dim: Extra size passed to encoder and decoder.
        """
        super().__init__()
        # Plain configuration values.
        self.name, self.z_dim, self.x_dim = name, z_dim, x_dim
        self.y_dim, self.c_dim = y_dim, c_dim
        self.warmup, self.var_pen = warmup, var_pen

        # Both networks share the same (z, x, y, c) size signature.
        arch = getattr(nns, nn)
        dims = (self.z_dim, self.x_dim, self.y_dim, self.c_dim)
        self.enc = arch.Encoder(*dims)
        self.dec = arch.Decoder(*dims)

        # Fixed (requires_grad=False) prior: mean 0 and `z_prior_v` 1
        # (presumably the variance), each of shape (1,).
        self.z_prior_m = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(torch.ones(1), requires_grad=False)
        self.z_prior = (self.z_prior_m, self.z_prior_v)
Exemplo n.º 2
0
    def __init__(self,
                 nn='v1',
                 name='vae',
                 z_dim=2,
                 x_dim=24,
                 warmup=False,
                 var_pen=1):
        """Construct the encoder/decoder pair and a fixed latent prior.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (codebase/models/nns/*.py) -- not ``torch.nn``, despite the
                unfortunate name clash.
            name: Model name kept on the instance.
            z_dim: Latent size passed to encoder and decoder.
            x_dim: Input size passed to encoder and decoder.
            warmup: Stored on the instance; not used in this constructor.
            var_pen: Stored on the instance; not used in this constructor.
        """
        super().__init__()
        self.name = name
        self.z_dim, self.x_dim = z_dim, x_dim
        self.warmup, self.var_pen = warmup, var_pen

        arch = getattr(nns, nn)
        sizes = (self.z_dim, self.x_dim)
        self.enc = arch.Encoder(*sizes)
        self.dec = arch.Decoder(*sizes)

        # Fixed (requires_grad=False) prior: mean 0 and `z_prior_v` 1
        # (presumably the variance), each of shape (1,).
        self.z_prior_m = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(torch.ones(1), requires_grad=False)
        self.z_prior = (self.z_prior_m, self.z_prior_v)
Exemplo n.º 3
0
    def __init__(self,
                 nn='v1',
                 name='vae',
                 z_dim=2,
                 z_prior_m=None,
                 z_prior_v=None):
        """Construct the encoder/decoder pair and the latent prior.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (codebase/models/nns/*.py) -- not ``torch.nn``.
            name: Model name kept on the instance.
            z_dim: Latent size passed to encoder and decoder, and the length
                of the default prior vectors.
            z_prior_m: Prior mean. When ``None``, a non-trainable zero
                vector of length ``z_dim`` is used; otherwise the given
                object is stored unchanged.
            z_prior_v: Second prior parameter (presumably the variance).
                When ``None``, a non-trainable ones vector of length
                ``z_dim`` is used; otherwise stored unchanged.
        """
        super().__init__()
        self.name = name
        self.z_dim = z_dim

        arch = getattr(nns, nn)
        self.enc = arch.Encoder(self.z_dim)
        self.dec = arch.Decoder(self.z_dim)

        # Fall back to a fixed standard prior when none is supplied.
        self.z_prior_m = (
            torch.nn.Parameter(torch.zeros(z_dim), requires_grad=False)
            if z_prior_m is None else z_prior_m
        )
        self.z_prior_v = (
            torch.nn.Parameter(torch.ones(z_dim), requires_grad=False)
            if z_prior_v is None else z_prior_v
        )
        self.z_prior = (self.z_prior_m, self.z_prior_v)
Exemplo n.º 4
0
    def __init__(self, nn='v1', encode_dim=None, name='vae', z_dim=2):
        """Construct the encoder/decoder pair and a fixed latent prior.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (not ``torch.nn``). The ``'popv'`` architecture is built with
                ``(encode_dim, z_dim)``; every other one with ``(z_dim,)``.
            encode_dim: First constructor argument for the ``'popv'``
                encoder/decoder; ignored otherwise.
            name: Model name kept on the instance.
            z_dim: Latent size.
        """
        super().__init__()
        self.name = name
        self.z_dim = z_dim

        arch = getattr(nns, nn)
        # Only 'popv' takes the extra leading encode_dim argument.
        ctor_args = (encode_dim, self.z_dim) if nn == 'popv' else (self.z_dim,)
        self.enc = arch.Encoder(*ctor_args)
        self.dec = arch.Decoder(*ctor_args)

        # Fixed (requires_grad=False) prior: mean 0 and `z_prior_v` 1
        # (presumably the variance), each of shape (1,).
        self.z_prior_m = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(torch.ones(1), requires_grad=False)
        self.z_prior = (self.z_prior_m, self.z_prior_v)
Exemplo n.º 5
0
    def __init__(self, nn='v2', name='fsvae'):
        """Construct the encoder/decoder pair and a fixed latent prior.

        Both latent size and ``y_dim`` are hard-wired to 10 for this model.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (not ``torch.nn``).
            name: Model name kept on the instance.
        """
        super().__init__()
        self.name = name
        self.z_dim, self.y_dim = 10, 10

        arch = getattr(nns, nn)
        self.enc = arch.Encoder(self.z_dim, self.y_dim)
        self.dec = arch.Decoder(self.z_dim, self.y_dim)

        # Fixed (requires_grad=False) prior: mean 0 and `z_prior_v` 1
        # (presumably the variance), each of shape (1,).
        prior_mean, prior_v = torch.zeros(1), torch.ones(1)
        self.z_prior_m = torch.nn.Parameter(prior_mean, requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(prior_v, requires_grad=False)
        self.z_prior = (self.z_prior_m, self.z_prior_v)
Exemplo n.º 6
0
    def __init__(self, nn='v8', name='vae3d', z_dim=16, device='cpu',
                 lambda_kl=0.01):
        """Construct the encoder/decoder pair and a fixed latent prior.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (not ``torch.nn``).
            name: Model name kept on the instance.
            z_dim: Latent size passed to encoder and decoder.
            device: Device forwarded to the encoder/decoder constructors and
                used to allocate the prior parameters, so that all of them
                start on the same device.
            lambda_kl: Stored on the instance; not used in this constructor
                (presumably a KL-term weight).
        """
        super().__init__()
        self.name = name
        self.z_dim = z_dim

        arch = getattr(nns, nn)
        self.enc = arch.Encoder(self.z_dim, device=device)
        self.dec = arch.Decoder(self.z_dim, device=device)
        self.lambda_kl = lambda_kl

        # Fixed (requires_grad=False) prior: mean 0 and `z_prior_v` 1
        # (presumably the variance), each of shape (1,). Allocate it on
        # `device` like enc/dec; previously it always landed on the default
        # device, causing a device mismatch unless `.to(device)` was called.
        self.z_prior_m = torch.nn.Parameter(torch.zeros(1, device=device),
                                            requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(torch.ones(1, device=device),
                                            requires_grad=False)
        self.z_prior = (self.z_prior_m, self.z_prior_v)
Exemplo n.º 7
0
    def __init__(self, nn='v1', z_dim=2, k=500, name='gmvae'):
        """Construct the encoder/decoder pair and a mixture-of-Gaussians prior.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (not ``torch.nn``).
            z_dim: Latent size.
            k: Number of mixture components.
            name: Model name kept on the instance.
        """
        super().__init__()
        self.name = name
        self.k = k
        self.z_dim = z_dim

        arch = getattr(nns, nn)
        self.enc = arch.Encoder(self.z_dim)
        self.dec = arch.Decoder(self.z_dim)

        # Trainable prior tensor of shape (1, 2*k, z_dim) -- presumably the k
        # component means and k component variances stacked -- initialised
        # from a standard normal scaled down by sqrt(k * z_dim).
        init_scale = np.sqrt(self.k * self.z_dim)
        self.z_pre = torch.nn.Parameter(
            torch.randn(1, 2 * self.k, self.z_dim) / init_scale)
        # Fixed (non-trainable) uniform mixing weights, 1/k each.
        self.pi = torch.nn.Parameter(torch.ones(k) / k, requires_grad=False)
Exemplo n.º 8
0
 def __init__(self, x_dim, z_dim, z_num, nn='nnet', name='gvae'):
     """Build the decoder, the global encoder and per-latent local encoders.

     Args:
         x_dim: Observed-data size.
         z_dim: Size of each latent variable.
         z_num: Number of latent variables; one bottom-up and one top-down
             ``LocalEncoder`` is created for each.
         nn: Attribute name of the architecture module inside ``nns``
             (not ``torch.nn``).
         name: Model name kept on the instance.
     """
     super().__init__()
     self.name = name
     self.x_dim = x_dim
     self.z_dim = z_dim
     self.z_num = z_num
     arch = getattr(nns, nn)
     # Decoder input size is z_dim * z_num (presumably all latents
     # concatenated).
     self.dec = arch.Decoder(z_dim * z_num, x_dim)
     self.gl_enc = arch.GlobalEncoder(x_dim, z_dim, z_num)
     # ModuleList rather than a plain list: submodules held in a Python list
     # are not registered, so their weights would be missing from
     # parameters() (never optimized), state_dict() and .to(device).
     self.bu_enc = torch.nn.ModuleList()
     self.td_enc = torch.nn.ModuleList()
     # Keep the bottom-up/top-down construction interleaved so the random
     # initialisation sequence is the same as before for a given seed.
     for _ in range(z_num):
         self.bu_enc.append(arch.LocalEncoder(z_dim, z_num))
         self.td_enc.append(arch.LocalEncoder(z_dim, z_num))
     # z_num * (z_num - 1) // 2 == "z_num choose 2" entries (presumably one
     # per pair of latents).
     self.mu = arch.Mu(torch.zeros(z_num * (z_num - 1) // 2))
Exemplo n.º 9
0
    def __init__(self, nn='v1', name='ssvae', gen_weight=1, class_weight=100):
        """Construct encoder, decoder, classifier and a fixed latent prior.

        Latent size is hard-wired to 64 and ``y_dim`` to 10.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (not ``torch.nn``).
            name: Model name kept on the instance.
            gen_weight: Stored on the instance; not used in this constructor.
            class_weight: Stored on the instance; not used in this
                constructor.
        """
        super().__init__()
        self.name = name
        self.z_dim, self.y_dim = 64, 10
        self.gen_weight, self.class_weight = gen_weight, class_weight

        arch = getattr(nns, nn)
        self.enc = arch.Encoder(self.z_dim, self.y_dim)
        self.dec = arch.Decoder(self.z_dim, self.y_dim)
        self.cls = arch.Classifier(self.y_dim)

        # Fixed (requires_grad=False) prior: mean 0 and `z_prior_v` 1
        # (presumably the variance), each of shape (1,).
        prior_mean, prior_v = torch.zeros(1), torch.ones(1)
        self.z_prior_m = torch.nn.Parameter(prior_mean, requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(prior_v, requires_grad=False)
        self.z_prior = (self.z_prior_m, self.z_prior_v)
Exemplo n.º 10
0
    def __init__(self,
                 nn='v1',
                 name='ssvae',
                 rec_weight=1,
                 kl_xy_x_weight=10,
                 kl_xy_y_weight=10,
                 gen_weight=1,
                 class_weight=100,
                 CNN=False):
        """Build three encoders, a decoder and a classifier for this model.

        Args:
            nn: Attribute name of the architecture module inside ``nns``
                (not ``torch.nn``).
            name: Model name kept on the instance.
            rec_weight: Stored on the instance (presumably a loss weight).
            kl_xy_x_weight: Stored on the instance (presumably a loss weight).
            kl_xy_y_weight: Stored on the instance (presumably a loss weight).
            gen_weight: Stored on the instance (presumably a loss weight).
            class_weight: Stored on the instance (presumably a loss weight).
            CNN: If True, build the encoders from the dedicated
                ``Encoder_XY`` / ``Encoder_X`` / ``Encoder_Y`` classes;
                otherwise build all three from the generic ``Encoder`` with
                different input sizes.
        """
        super().__init__()
        self.name = name
        self.CNN = CNN
        # Fixed sizes. 784 == 28 * 28 (presumably flattened 28x28 images).
        self.x_dim = 784
        self.z_dim = 64
        self.y_dim = 10

        self.rec_weight = rec_weight
        self.kl_xy_x_weight = kl_xy_x_weight
        self.kl_xy_y_weight = kl_xy_y_weight

        self.gen_weight = gen_weight
        self.class_weight = class_weight

        # `nn` is rebound to the architecture module from `nns`; it shadows
        # the string argument (and is unrelated to torch.nn).
        nn = getattr(nns, nn)

        if CNN:
            self.enc_xy = nn.Encoder_XY(z_dim=self.z_dim, y_dim=self.y_dim)
            self.enc_x = nn.Encoder_X(z_dim=self.z_dim)
            self.enc_y = nn.Encoder_Y(z_dim=self.z_dim, y_dim=self.y_dim)
        else:
            # Same Encoder class, distinguished by which inputs are enabled:
            # enc_xy sees x and y, enc_x sees only x (y_dim=0), enc_y sees
            # only y (x_dim=0).
            self.enc_xy = nn.Encoder(z_dim=self.z_dim,
                                     y_dim=self.y_dim,
                                     x_dim=self.x_dim)
            self.enc_x = nn.Encoder(z_dim=self.z_dim,
                                    y_dim=0,
                                    x_dim=self.x_dim)
            self.enc_y = nn.Encoder(z_dim=self.z_dim,
                                    y_dim=self.y_dim,
                                    x_dim=0)

        # Decoder is built with y_dim=0, i.e. no y input.
        self.dec = nn.Decoder(z_dim=self.z_dim, y_dim=0, x_dim=self.x_dim)

        # Classifier input is twice the latent size (presumably two latent
        # codes concatenated -- confirm against the forward/loss code).
        self.cls = nn.Classifier(y_dim=self.y_dim, input_dim=self.z_dim * 2)