def __init__(self):
    """Create seeded random inputs and two small Gaussian-head networks.

    Seeds the TensorFlow graph-level RNG, then creates five variables
    initialized from ``tf.random_normal_initializer()``:

    * ``x``  -- shape ``[BATCH_SIZE, X_DIMS]``
    * ``x2`` -- shape ``[N_X, BATCH_SIZE, X_DIMS]``
    * ``x3`` -- shape ``[N_X, N_Z, BATCH_SIZE, X_DIMS]``
    * ``z``  -- shape ``[BATCH_SIZE, Z_DIMS]``
    * ``z2`` -- shape ``[N_Z, BATCH_SIZE, Z_DIMS]``

    It also builds ``h_for_p_x`` / ``h_for_q_z``: a 100-unit ReLU ``Dense``
    layer followed by a ``DictMapper`` emitting ``mean`` and ``logstd``.
    Both networks are invoked once so that their variables exist right away.
    """
    super(Helper, self).__init__()
    tf.set_random_seed(1234)

    def normal_variable(var_name, var_shape):
        # Each variable gets its own fresh normal initializer.
        return tf.get_variable(
            var_name, shape=var_shape,
            initializer=tf.random_normal_initializer())

    # Keep this creation order: with a graph-level seed set, TF derives
    # per-op seeds deterministically from the graph, so reordering the
    # variables could change their sampled initial values.
    self.x = normal_variable('x', [BATCH_SIZE, X_DIMS])
    self.x2 = normal_variable('x2', [N_X, BATCH_SIZE, X_DIMS])
    self.x3 = normal_variable('x3', [N_X, N_Z, BATCH_SIZE, X_DIMS])
    self.z = normal_variable('z', [BATCH_SIZE, Z_DIMS])
    self.z2 = normal_variable('z2', [N_Z, BATCH_SIZE, Z_DIMS])

    def gaussian_head(out_dims):
        # hidden ReLU layer, then parallel `mean` / `logstd` projections
        return Sequential([
            K.layers.Dense(100, activation=tf.nn.relu),
            DictMapper({
                'mean': K.layers.Dense(out_dims),
                'logstd': K.layers.Dense(out_dims),
            }),
        ])

    self.h_for_p_x = gaussian_head(X_DIMS)
    self.h_for_q_z = gaussian_head(Z_DIMS)

    # Invoke each network once so its variables are created here.
    _ = self.h_for_p_x(self.z)
    _ = self.h_for_q_z(self.x)
def __init__(self, h_for_p_x, h_for_q_z, x_dims, z_dims, std_epsilon=1e-4,
             name=None, scope=None):
    """Build the Donut VAE on top of user-supplied hidden networks.

    Args:
        h_for_p_x: Hidden network for p(x|z). Its output feeds a
            ``DictMapper`` that yields the ``mean`` and ``std`` of a
            ``Normal`` over x.
        h_for_q_z: Hidden network for q(z|x), treated the same way on the
            z side.
        x_dims (int): Dimension of x; must be a positive integer.
        z_dims (int): Dimension of z; must be a positive integer.
        std_epsilon (float): Constant added to every softplus ``std``
            output so the standard deviations stay strictly positive.
        name: Forwarded to the base-class constructor.
        scope: Forwarded to the base-class constructor.

    Raises:
        ValueError: If ``x_dims`` or ``z_dims`` is not a positive integer
            (``x_dims`` is checked first).
    """
    for dims_value, message in (
            (x_dims, '`x_dims` must be a positive integer'),
            (z_dims, '`z_dims` must be a positive integer')):
        if not is_integer(dims_value) or dims_value <= 0:
            raise ValueError(message)

    super(Donut, self).__init__(name=name, scope=scope)

    def normal_params_net(hidden_net, out_dims, mapper_name):
        # `mean` is a plain affine projection; `std` is a softplus projection
        # shifted by `std_epsilon`.
        # NOTE(review): the `std` lambda constructs a new K.layers.Dense each
        # time it is called; whether repeated calls share weights depends on
        # how DictMapper / the enclosing variable scope handle reuse -- confirm.
        return Sequential([
            hidden_net,
            DictMapper(
                {
                    'mean': K.layers.Dense(out_dims),
                    'std': lambda t: (std_epsilon + K.layers.Dense(
                        out_dims, activation=tf.nn.softplus)(t)),
                },
                name=mapper_name),
        ])

    with reopen_variable_scope(self.variable_scope):
        # Construction order matches the original keyword-argument order:
        # prior, then the p(x|z) head, then the q(z|x) head.
        prior = Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims]))
        p_x_net = normal_params_net(h_for_p_x, x_dims, 'p_x_given_z')
        q_z_net = normal_params_net(h_for_q_z, z_dims, 'q_z_given_x')
        self._vae = VAE(
            p_z=prior,
            p_x_given_z=Normal,
            q_z_given_x=Normal,
            h_for_p_x=p_x_net,
            h_for_q_z=q_z_net,
        )

    self._x_dims = x_dims
    self._z_dims = z_dims
def test_outputs(self):
    """DictMapper returns a dict of tensors keyed like its mapper, and a
    second call creates no additional variables."""
    branch_a = Lambda(lambda x: x * tf.get_variable('x', initializer=0.))
    net = DictMapper({
        'a': branch_a,
        'b': lambda x: x * tf.get_variable('y', initializer=0.),
    })
    inputs = tf.placeholder(dtype=tf.float32, shape=[None, 2])

    first_output = net(inputs)
    self.assertIsInstance(first_output, dict)
    self.assertEqual(sorted(first_output), ['a', 'b'])
    for tensor in first_output.values():
        self.assertIsInstance(tensor, tf.Tensor)

    # Calling again must reuse the variables created by the first call.
    _ = net(inputs)
    variable_names = sorted(var.name for var in tf.global_variables())
    self.assertEqual(variable_names, ['dict_mapper/b/y:0', 'lambda/x:0'])
def main():
    """Train a VAE with a 2-D latent space on MNIST and plot test-set codes.

    Builds a VAE with a Bernoulli p(x|z) and Normal q(z|x) in a fresh graph,
    trains it with ``train``, then encodes the test images batch by batch
    and scatter-plots the latent ``z`` coloured by the digit labels.
    """
    # load mnist data
    (train_x, train_y), (test_x, test_y) = datasets.load_mnist()

    # the parameters of this experiment
    x_dim = train_x.shape[1]
    z_dim = 2
    max_epoch = 10
    batch_size = 256
    valid_portion = 0.2

    # Construct the graph. Entering ``tf.Session()`` itself (rather than
    # ``tf.Session().as_default()``) both installs it as the default session
    # and closes it on exit, so its resources are released instead of leaked.
    with tf.Graph().as_default(), tf.Session() as session:
        input_x = tf.placeholder(
            dtype=tf.float32, shape=(None, x_dim), name='input_x')
        # gradients must not flow through the (stochastic) binarization
        x_binarized = tf.stop_gradient(sample_input_x(input_x))
        batch_size_tensor = tf.shape(input_x)[0]

        # derive the VAE; the prior's shape follows the dynamic batch size
        z_shape = tf.stack([batch_size_tensor, z_dim])
        vae = VAE(
            p_z=Normal(mean=tf.zeros(z_shape), std=tf.ones(z_shape)),
            p_x_given_z=Bernoulli,
            q_z_given_x=Normal,
            h_for_p_x=Sequential([
                K.layers.Dense(100, activation=tf.nn.relu),
                K.layers.Dense(100, activation=tf.nn.relu),
                DictMapper(
                    {'logits': K.layers.Dense(x_dim, name='x_logits')})
            ]),
            h_for_q_z=Sequential([
                tf.to_float,
                K.layers.Dense(100, activation=tf.nn.relu),
                K.layers.Dense(100, activation=tf.nn.relu),
                DictMapper({
                    'mean': K.layers.Dense(z_dim, name='z_mean'),
                    'logstd': K.layers.Dense(z_dim, name='z_logstd'),
                })
            ]))

        # train the network
        train(vae.get_training_loss(x_binarized), input_x, train_x,
              max_epoch, batch_size, valid_portion)

        # encode the test set; keep order and the final partial batch so
        # that row i of `z_predict` still lines up with `test_y[i]`
        q_net = vae.variational(x_binarized)
        z_posterior = q_net['z']
        z_predict = []
        for [batch_x] in DataFlow.arrays([test_x], batch_size=batch_size,
                                         shuffle=False,
                                         skip_incomplete=False):
            z_predict.append(
                session.run(z_posterior, feed_dict={input_x: batch_x}))
        z_predict = np.concatenate(z_predict, axis=0)

    # Plot after the session has been closed: only NumPy data is needed
    # here, and plt.show() blocks, which previously kept the session open.
    plt.figure(figsize=(8, 6))
    plt.scatter(z_predict[:, 0], z_predict[:, 1], c=test_y)
    plt.colorbar()
    plt.grid()
    plt.show()
def test_invalid_key(self):
    """Keys that are not valid Python identifiers are rejected."""
    expected_message = ('The key for `DictMapper` must be a valid '
                        'Python identifier')
    bad_keys = ('.', '', '90ab', 'abc.def')
    for bad_key in bad_keys:
        with pytest.raises(ValueError, match=expected_message):
            _ = DictMapper({bad_key: lambda x: x})
def test_empty_mapper(self):
    """Constructing a DictMapper from an empty dict raises ValueError."""
    empty_mapper = {}
    with pytest.raises(ValueError, match='`mapper` must not be empty'):
        DictMapper(empty_mapper)