Example #1
0
    def __init__(self):
        """Create the random variables and the two hidden networks.

        Five variables (``x``, ``x2``, ``x3``, ``z``, ``z2``) are created
        through ``tf.get_variable`` with a random-normal initializer and
        stored on the instance under the same names.  Two `Sequential`
        networks, each ending in a `DictMapper` with ``mean`` and
        ``logstd`` entries, are then built and called once so that their
        variables exist.
        """
        super(Helper, self).__init__()
        tf.set_random_seed(1234)

        # (name, shape) pairs, listed in the order the variables are
        # created; each variable is stored on `self` under its own name.
        variable_shapes = (
            ('x', [BATCH_SIZE, X_DIMS]),
            ('x2', [N_X, BATCH_SIZE, X_DIMS]),
            ('x3', [N_X, N_Z, BATCH_SIZE, X_DIMS]),
            ('z', [BATCH_SIZE, Z_DIMS]),
            ('z2', [N_Z, BATCH_SIZE, Z_DIMS]),
        )
        for var_name, var_shape in variable_shapes:
            variable = tf.get_variable(
                var_name,
                shape=var_shape,
                initializer=tf.random_normal_initializer())
            setattr(self, var_name, variable)

        def build_hidden_net(output_dims):
            # A 100-unit ReLU layer followed by two linear heads, 'mean'
            # and 'logstd', each `output_dims` wide.
            return Sequential([
                K.layers.Dense(100, activation=tf.nn.relu),
                DictMapper({'mean': K.layers.Dense(output_dims),
                            'logstd': K.layers.Dense(output_dims)})
            ])

        self.h_for_p_x = build_hidden_net(X_DIMS)
        self.h_for_q_z = build_hidden_net(Z_DIMS)

        # Call each network once so that its variables get created.
        self.h_for_p_x(self.z)
        self.h_for_q_z(self.x)
Example #2
0
    def __init__(self,
                 h_for_p_x,
                 h_for_q_z,
                 x_dims,
                 z_dims,
                 std_epsilon=1e-4,
                 name=None,
                 scope=None):
        """Build the underlying `VAE` inside this object's variable scope.

        The VAE uses a `Normal` prior with zero mean and unit std of size
        `z_dims`, and `Normal` for both conditional distributions.

        Args:
            h_for_p_x: Network placed in front of the ``mean`` / ``std``
                heads that parameterize ``p_x_given_z``.
            h_for_q_z: Network placed in front of the ``mean`` / ``std``
                heads that parameterize ``q_z_given_x``.
            x_dims: Positive integer, the width of the ``p_x_given_z``
                heads.
            z_dims: Positive integer, the width of the ``q_z_given_x``
                heads and of the prior.
            std_epsilon: Constant added to the softplus output of every
                ``std`` head.
            name: Forwarded to the base-class constructor.
            scope: Forwarded to the base-class constructor.

        Raises:
            ValueError: If `x_dims` or `z_dims` is not a positive integer.
        """
        # Validate the dimensions before anything is constructed.
        dimension_checks = (
            (x_dims, '`x_dims` must be a positive integer'),
            (z_dims, '`z_dims` must be a positive integer'),
        )
        for dims_value, error_message in dimension_checks:
            if not (is_integer(dims_value) and dims_value > 0):
                raise ValueError(error_message)

        super(Donut, self).__init__(name=name, scope=scope)

        def build_params_head(dims, head_name):
            # 'mean' is a plain linear layer; 'std' is a softplus layer
            # whose output is shifted by `std_epsilon`.
            return DictMapper(
                {
                    'mean': K.layers.Dense(dims),
                    'std': lambda inputs: (
                        std_epsilon + K.layers.Dense(
                            dims, activation=tf.nn.softplus)(inputs)),
                },
                name=head_name)

        with reopen_variable_scope(self.variable_scope):
            prior = Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims]))
            p_x_net = Sequential(
                [h_for_p_x, build_params_head(x_dims, 'p_x_given_z')])
            q_z_net = Sequential(
                [h_for_q_z, build_params_head(z_dims, 'q_z_given_x')])
            self._vae = VAE(
                p_z=prior,
                p_x_given_z=Normal,
                q_z_given_x=Normal,
                h_for_p_x=p_x_net,
                h_for_q_z=q_z_net,
            )

        self._x_dims = x_dims
        self._z_dims = z_dims
Example #3
0
    def test_outputs(self):
        """Check the dict returned by `DictMapper` and its variable names."""
        mapper = DictMapper({
            'a': Lambda(lambda x: x * tf.get_variable('x', initializer=0.)),
            'b': lambda x: x * tf.get_variable('y', initializer=0.)
        })
        batch = tf.placeholder(dtype=tf.float32, shape=[None, 2])

        # First call: the result is a dict holding one tensor per key.
        result = mapper(batch)
        self.assertIsInstance(result, dict)
        self.assertEqual(sorted(result.keys()), ['a', 'b'])
        for tensor in six.itervalues(result):
            self.assertIsInstance(tensor, tf.Tensor)

        # Second call: the global variables must still be exactly the
        # two created by the mapped functions.
        mapper(batch)
        variable_names = sorted(
            variable.name for variable in tf.global_variables())
        self.assertEqual(
            variable_names,
            ['dict_mapper/b/y:0',
             'lambda/x:0']
        )
Example #4
0
def main():
    """Train a VAE on MNIST and scatter-plot the test-set latent outputs.

    The model uses a `Bernoulli` ``p_x_given_z`` and a `Normal`
    ``q_z_given_x`` with a 2-dimensional latent space.  After training,
    the ``z`` output of the variational net is evaluated batch by batch
    on the test images and plotted, coloured by ``test_y``.

    The session is opened with ``with tf.Session() as session`` so that
    it is both installed as the default session and closed when the
    block exits; ``tf.Session().as_default()`` alone does not close it.
    """
    # load mnist data (the training labels are not used)
    (train_x, _), (test_x, test_y) = datasets.load_mnist()

    # the parameters of this experiment
    x_dim = train_x.shape[1]
    z_dim = 2
    max_epoch = 10
    batch_size = 256
    valid_portion = 0.2

    # construct the graph; the session is created while the new graph is
    # the default one, so it is bound to that graph
    with tf.Graph().as_default(), tf.Session() as session:
        input_x = tf.placeholder(dtype=tf.float32,
                                 shape=(None, x_dim),
                                 name='input_x')
        x_binarized = tf.stop_gradient(sample_input_x(input_x))
        batch_size_tensor = tf.shape(input_x)[0]

        # derive the VAE
        z_shape = tf.stack([batch_size_tensor, z_dim])
        vae = VAE(p_z=Normal(mean=tf.zeros(z_shape), std=tf.ones(z_shape)),
                  p_x_given_z=Bernoulli,
                  q_z_given_x=Normal,
                  h_for_p_x=Sequential([
                      K.layers.Dense(100, activation=tf.nn.relu),
                      K.layers.Dense(100, activation=tf.nn.relu),
                      DictMapper(
                          {'logits': K.layers.Dense(x_dim, name='x_logits')})
                  ]),
                  h_for_q_z=Sequential([
                      tf.to_float,
                      K.layers.Dense(100, activation=tf.nn.relu),
                      K.layers.Dense(100, activation=tf.nn.relu),
                      DictMapper({
                          'mean':
                          K.layers.Dense(z_dim, name='z_mean'),
                          'logstd':
                          K.layers.Dense(z_dim, name='z_logstd'),
                      })
                  ]))

        # train the network
        train(vae.get_training_loss(x_binarized), input_x, train_x, max_epoch,
              batch_size, valid_portion)

        # plot the latent space
        q_net = vae.variational(x_binarized)
        z_posterior = q_net['z']

        # evaluate `z` one mini-batch at a time, then join the batches
        z_predict = np.concatenate(
            [session.run(z_posterior, feed_dict={input_x: batch_x})
             for [batch_x] in DataFlow.arrays([test_x],
                                              batch_size=batch_size)],
            axis=0)

        plt.figure(figsize=(8, 6))
        plt.scatter(z_predict[:, 0], z_predict[:, 1], c=test_y)
        plt.colorbar()
        plt.grid()
        plt.show()
Example #5
0
 def test_invalid_key(self):
     """Each malformed key must make `DictMapper` raise `ValueError`."""
     expected_message = ('The key for `DictMapper` must be a valid '
                         'Python identifier')
     invalid_keys = ['.', '', '90ab', 'abc.def']
     for bad_key in invalid_keys:
         with pytest.raises(ValueError, match=expected_message):
             DictMapper({bad_key: lambda x: x})
Example #6
0
 def test_empty_mapper(self):
     """An empty mapping must make `DictMapper` raise `ValueError`."""
     expected_message = '`mapper` must not be empty'
     with pytest.raises(ValueError, match=expected_message):
         DictMapper({})