def _initialize_networks(self):
    """ Initialize all model networks.

    Creates, in this order, the attributes ``Pz_x`` (only when
    ``self.TYPE_PX`` is 'Gaussian' or 'Bernoulli'), ``Pzx_y``, ``Wtilde``,
    ``Qxy_z`` and ``Qx_y``. For any other ``TYPE_PX`` value ``Pz_x`` is
    not assigned here.
    """
    # Pick the constructor for Pz_x from the configured likelihood type.
    likelihood = self.TYPE_PX
    if likelihood == 'Gaussian':
        build_px = dgm._init_Gauss_net
    elif likelihood == 'Bernoulli':
        build_px = dgm._init_Cat_net
    else:
        build_px = None

    x_dim = self.X_DIM
    z_dim = self.Z_DIM
    hidden = self.NUM_HIDDEN
    n_classes = self.NUM_CLASSES

    if build_px is not None:
        self.Pz_x = build_px(z_dim, hidden, x_dim, 'Pz_x_',
                             bn=self.batchnorm)

    # Pzx_y and Wtilde share the same first-argument size (Z_DIM + X_DIM).
    zx_dim = z_dim + x_dim
    self.Pzx_y = dgm._init_Cat_bnn(zx_dim, hidden, n_classes,
                                   'Pzx_y', self.initVar)
    self.Wtilde = self._init_Wtilde(zx_dim, hidden, n_classes,
                                    'W_tilde_', bn=self.batchnorm)

    # NOTE(review): the Q* networks are presumably the inference-side
    # counterparts of the P* networks -- confirm against the model class.
    self.Qxy_z = dgm._init_Gauss_net(x_dim + n_classes, hidden, z_dim,
                                     'Qxy_z', bn=self.batchnorm)
    self.Qx_y = dgm._init_Cat_net(x_dim, hidden, n_classes,
                                  'Qx_y', bn=self.batchnorm)
def _initialize_networks(self):
    """ Initialize all model networks.

    Creates, in this order, the attributes ``Pza_x`` (only when
    ``self.TYPE_PX`` is 'Gaussian' or 'Bernoulli'), ``Pzax_y``, ``Pz_a``,
    ``Qxya_z``, ``Qx_a`` and ``Qxa_y``. For any other ``TYPE_PX`` value
    ``Pza_x`` is not assigned here.
    """
    # Pick the constructor for Pza_x from the configured likelihood type.
    likelihood = self.TYPE_PX
    if likelihood == 'Gaussian':
        build_px = dgm._init_Gauss_net
    elif likelihood == 'Bernoulli':
        build_px = dgm._init_Cat_net
    else:
        build_px = None

    x_dim = self.X_DIM
    z_dim = self.Z_DIM
    a_dim = self.A_DIM
    hidden = self.NUM_HIDDEN
    n_classes = self.NUM_CLASSES
    use_bn = self.batchnorm

    za_dim = z_dim + a_dim
    if build_px is not None:
        self.Pza_x = build_px(za_dim, hidden, x_dim, 'Pza_x_', use_bn)

    self.Pzax_y = dgm._init_Cat_net(za_dim + x_dim, hidden, n_classes,
                                    'Pzax_y_', use_bn)
    self.Pz_a = dgm._init_Gauss_net(z_dim, hidden, a_dim, 'Pz_a_', use_bn)

    # NOTE(review): the Q* networks are presumably the inference-side
    # counterparts of the P* networks -- confirm against the model class.
    xa_dim = x_dim + a_dim
    self.Qxya_z = dgm._init_Gauss_net(xa_dim + n_classes, hidden, z_dim,
                                      'Qxya_z_', use_bn)
    self.Qx_a = dgm._init_Gauss_net(x_dim, hidden, a_dim, 'Qx_a_', use_bn)
    self.Qxa_y = dgm._init_Cat_net(xa_dim, hidden, n_classes,
                                   'Qxa_y_', use_bn)
def _initialize_networks(self):
    """ Initialize all model networks.

    Creates, in this order, the attributes ``Pzy_x`` (only when
    ``self.TYPE_PX`` is 'Gaussian' or 'Bernoulli'), ``Py``, ``Qxy_z`` and
    ``Qx_y``. For any other ``TYPE_PX`` value ``Pzy_x`` is not assigned
    here.
    """
    # Pick the constructor for Pzy_x from the configured likelihood type.
    likelihood = self.TYPE_PX
    if likelihood == 'Gaussian':
        build_px = dgm._init_Gauss_net
    elif likelihood == 'Bernoulli':
        build_px = dgm._init_Cat_net
    else:
        build_px = None

    x_dim = self.X_DIM
    z_dim = self.Z_DIM
    hidden = self.NUM_HIDDEN
    n_classes = self.NUM_CLASSES
    use_bn = self.batchnorm

    if build_px is not None:
        self.Pzy_x = build_px(z_dim + n_classes, hidden, x_dim,
                              'Pzy_x', use_bn)

    # Py is a constant float32 vector of length NUM_CLASSES whose entries
    # are all 1 / NUM_CLASSES.
    uniform = (1. / n_classes) * np.ones(shape=(n_classes, ))
    self.Py = tf.constant(uniform, dtype=tf.float32)

    # NOTE(review): the Q* networks are presumably the inference-side
    # counterparts of the P* networks -- confirm against the model class.
    self.Qxy_z = dgm._init_Gauss_net(x_dim + n_classes, hidden, z_dim,
                                     'Qxy_z', use_bn)
    self.Qx_y = dgm._init_Cat_net(x_dim, hidden, n_classes,
                                  'Qx_y', use_bn)
def _initialize_networks(self):
    """ Initialize all model networks.

    Creates ``Pz_x`` (only when ``self.TYPE_PX`` is 'Gaussian' or
    'Bernoulli') and then ``Qx_z``. For any other ``TYPE_PX`` value
    ``Pz_x`` is not assigned here.
    """
    # Pick the constructor for Pz_x from the configured likelihood type.
    likelihood = self.TYPE_PX
    if likelihood == 'Gaussian':
        build_px = dgm._init_Gauss_net
    elif likelihood == 'Bernoulli':
        build_px = dgm._init_Cat_net
    else:
        build_px = None

    x_dim = self.X_DIM
    z_dim = self.Z_DIM
    hidden = self.NUM_HIDDEN
    use_bn = self.batchnorm

    if build_px is not None:
        self.Pz_x = build_px(z_dim, hidden, x_dim, 'Pz_x_', use_bn)

    # Qx_z is always a Gaussian net, with the X/Z sizes swapped relative
    # to Pz_x.
    self.Qx_z = dgm._init_Gauss_net(x_dim, hidden, z_dim, 'Qx_z_', use_bn)