Example #1
0
 def _initialize_networks(self):
     """Build every network of this model and store each one on ``self``.

     Attributes assigned: ``Pz_x`` (only when ``TYPE_PX`` equals
     ``'Gaussian'`` or ``'Bernoulli'``), ``Pzx_y``, ``Wtilde``, ``Qxy_z``
     and ``Qx_y``.  For any other ``TYPE_PX`` value ``Pz_x`` is not
     assigned by this method.

     NOTE(review): the builder arguments look like (input size, hidden
     layout, output size, variable-name prefix, batch-norm flag) --
     confirm against the ``dgm`` helpers.
     """
     # Select the builder for Pz_x from the configured TYPE_PX.
     px_type = self.TYPE_PX
     if px_type == 'Gaussian':
         build_px = dgm._init_Gauss_net
     elif px_type == 'Bernoulli':
         build_px = dgm._init_Cat_net
     else:
         build_px = None

     if build_px is not None:
         self.Pz_x = build_px(
             self.Z_DIM, self.NUM_HIDDEN, self.X_DIM, 'Pz_x_',
             bn=self.batchnorm)

     # Pzx_y and Wtilde both receive Z_DIM + X_DIM as their first argument.
     # Pzx_y is the only network given ``initVar`` instead of ``bn``.
     self.Pzx_y = dgm._init_Cat_bnn(
         self.Z_DIM + self.X_DIM, self.NUM_HIDDEN, self.NUM_CLASSES,
         'Pzx_y', self.initVar)
     self.Wtilde = self._init_Wtilde(
         self.Z_DIM + self.X_DIM, self.NUM_HIDDEN, self.NUM_CLASSES,
         'W_tilde_', bn=self.batchnorm)

     # The two Q networks: Qxy_z takes X_DIM + NUM_CLASSES, Qx_y takes X_DIM.
     self.Qxy_z = dgm._init_Gauss_net(
         self.X_DIM + self.NUM_CLASSES, self.NUM_HIDDEN, self.Z_DIM,
         'Qxy_z', bn=self.batchnorm)
     self.Qx_y = dgm._init_Cat_net(
         self.X_DIM, self.NUM_HIDDEN, self.NUM_CLASSES,
         'Qx_y', bn=self.batchnorm)
Example #2
0
    def _initialize_networks(self):
        """Initialize all model networks and store them on ``self``.

        Attributes assigned: ``Pza_x`` (only when ``TYPE_PX`` equals
        ``'Gaussian'`` or ``'Bernoulli'``), ``Pzax_y``, ``Pz_a``,
        ``Qxya_z``, ``Qx_a`` and ``Qxa_y``.  For any other ``TYPE_PX``
        value ``Pza_x`` is not assigned by this method.

        NOTE(review): the builder arguments look like (input size, hidden
        layout, output size, variable-name prefix, batch-norm flag) --
        confirm against the ``dgm`` helpers.
        """
        # Indentation uses spaces only: the previous tab/space mix is
        # rejected by Python 3 with a TabError.
        if self.TYPE_PX == 'Gaussian':
            self.Pza_x = dgm._init_Gauss_net(
                self.Z_DIM + self.A_DIM, self.NUM_HIDDEN, self.X_DIM,
                'Pza_x_', self.batchnorm)
        elif self.TYPE_PX == 'Bernoulli':
            self.Pza_x = dgm._init_Cat_net(
                self.Z_DIM + self.A_DIM, self.NUM_HIDDEN, self.X_DIM,
                'Pza_x_', self.batchnorm)
        self.Pzax_y = dgm._init_Cat_net(
            self.Z_DIM + self.A_DIM + self.X_DIM, self.NUM_HIDDEN,
            self.NUM_CLASSES, 'Pzax_y_', self.batchnorm)
        self.Pz_a = dgm._init_Gauss_net(
            self.Z_DIM, self.NUM_HIDDEN, self.A_DIM,
            'Pz_a_', self.batchnorm)
        self.Qxya_z = dgm._init_Gauss_net(
            self.X_DIM + self.A_DIM + self.NUM_CLASSES, self.NUM_HIDDEN,
            self.Z_DIM, 'Qxya_z_', self.batchnorm)
        self.Qx_a = dgm._init_Gauss_net(
            self.X_DIM, self.NUM_HIDDEN, self.A_DIM,
            'Qx_a_', self.batchnorm)
        self.Qxa_y = dgm._init_Cat_net(
            self.X_DIM + self.A_DIM, self.NUM_HIDDEN, self.NUM_CLASSES,
            'Qxa_y_', self.batchnorm)
Example #3
0
 def _initialize_networks(self):
     """Build every network of this model and store each one on ``self``.

     Attributes assigned: ``Pzy_x`` (only when ``TYPE_PX`` equals
     ``'Gaussian'`` or ``'Bernoulli'``), ``Py``, ``Qxy_z`` and ``Qx_y``.
     For any other ``TYPE_PX`` value ``Pzy_x`` is not assigned by this
     method.

     NOTE(review): the builder arguments look like (input size, hidden
     layout, output size, variable name, batch-norm flag) -- confirm
     against the ``dgm`` helpers.
     """
     # Select the builder for Pzy_x from the configured TYPE_PX.
     px_type = self.TYPE_PX
     if px_type == 'Gaussian':
         build_px = dgm._init_Gauss_net
     elif px_type == 'Bernoulli':
         build_px = dgm._init_Cat_net
     else:
         build_px = None

     if build_px is not None:
         self.Pzy_x = build_px(
             self.Z_DIM + self.NUM_CLASSES, self.NUM_HIDDEN, self.X_DIM,
             'Pzy_x', self.batchnorm)

     # Py is a constant float32 vector of length NUM_CLASSES whose entries
     # are all 1 / NUM_CLASSES.
     per_class = 1. / self.NUM_CLASSES
     uniform = per_class * np.ones(shape=(self.NUM_CLASSES, ))
     self.Py = tf.constant(uniform, dtype=tf.float32)

     # The two Q networks: Qxy_z takes X_DIM + NUM_CLASSES, Qx_y takes X_DIM.
     self.Qxy_z = dgm._init_Gauss_net(
         self.X_DIM + self.NUM_CLASSES, self.NUM_HIDDEN, self.Z_DIM,
         'Qxy_z', self.batchnorm)
     self.Qx_y = dgm._init_Cat_net(
         self.X_DIM, self.NUM_HIDDEN, self.NUM_CLASSES,
         'Qx_y', self.batchnorm)
Example #4
0
    def _initialize_networks(self):
        """Initialize all model networks and store them on ``self``.

        Attributes assigned: ``Pz_x`` (only when ``TYPE_PX`` equals
        ``'Gaussian'`` or ``'Bernoulli'``) and ``Qx_z``.  For any other
        ``TYPE_PX`` value ``Pz_x`` is not assigned by this method.

        NOTE(review): the builder arguments look like (input size, hidden
        layout, output size, variable-name prefix, batch-norm flag) --
        confirm against the ``dgm`` helpers.
        """
        # Indentation uses spaces only: the previous tab/space mix is
        # rejected by Python 3 with a TabError.
        if self.TYPE_PX == 'Gaussian':
            self.Pz_x = dgm._init_Gauss_net(
                self.Z_DIM, self.NUM_HIDDEN, self.X_DIM,
                'Pz_x_', self.batchnorm)
        elif self.TYPE_PX == 'Bernoulli':
            self.Pz_x = dgm._init_Cat_net(
                self.Z_DIM, self.NUM_HIDDEN, self.X_DIM,
                'Pz_x_', self.batchnorm)
        self.Qx_z = dgm._init_Gauss_net(
            self.X_DIM, self.NUM_HIDDEN, self.Z_DIM,
            'Qx_z_', self.batchnorm)