def step(self, obs):
    # Promote a single observation to a batch of one.
    if obs.ndim < 2:
        obs = obs[np.newaxis, :]
    # Sample an action from the categorical policy and return it together
    # with its log-probability.
    action_probs = self.actor(obs)
    dist = Categorical(probs=action_probs)
    action = dist.sample()
    return action.numpy()[0], dist.log_prob(action)
def call(self, state):
    # Promote a single state to a batch of one.
    if state.ndim < 2:
        state = state[np.newaxis, :]
    # Sample actions and their log-probabilities from the policy network.
    action_probs = self.net(state)
    dist = Categorical(probs=action_probs)
    action = dist.sample()
    log_pi = dist.log_prob(action)
    return action.numpy(), log_pi
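
# A minimal, self-contained illustration of the sample/log_prob pattern used
# by step() and call() above. The snippet below assumes torch.distributions
# (aliased to avoid clashing with the `Categorical` used elsewhere in this
# file); the tensorflow_probability API behaves the same way in eager mode.
import torch
from torch.distributions import Categorical as TorchCategorical

probs = torch.tensor([[0.1, 0.6, 0.3]])  # one state, three actions
dist = TorchCategorical(probs=probs)
action = dist.sample()                   # tensor of shape (1,)
log_pi = dist.log_prob(action)           # log pi(a|s), shape (1,)
print(action.numpy()[0], float(log_pi))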
# Imports needed by train() below (TF1-style graph mode). `encoder`, `decoder`,
# `plot_squares`, and `FLAGS` are assumed to be defined elsewhere in the
# project; a reference sketch of `gumbel_softmax` appears at the end of this
# section. `Categorical` is assumed to come from tensorflow_probability.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.examples.tutorials.mnist import input_data
from tqdm import tqdm

Categorical = tfp.distributions.Categorical


def train():
    # Load MNIST data.
    data = input_data.read_data_sets(FLAGS.data_dir + '/MNIST', one_hot=True)

    # Create encoder graph.
    with tf.variable_scope("encoder"):
        inputs = tf.placeholder(tf.float32,
                                shape=[None, 28 * 28],
                                name='inputs')
        tau = tf.placeholder(tf.float32, shape=[], name='temperature')
        logits = encoder(inputs)
        # Relaxed, differentiable one-hot samples; shape
        # (batch_size, num_cat_dists, num_classes).
        z = gumbel_softmax(logits, tau, hard=False)

    # Create decoder graph.
    with tf.variable_scope("decoder"):
        p_x_given_z = decoder(z)

    with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
        categorical = Categorical(probs=np.ones(FLAGS.num_classes) /
                                  FLAGS.num_classes)
        z = categorical.sample(
            sample_shape=[FLAGS.batch_size, FLAGS.num_cat_dists])
        z = tf.one_hot(z, depth=FLAGS.num_classes)
        p_x_given_z_eval = decoder(z)

    # Define the loss function and the train operator.
    # NOTE: A uniform categorical prior p(z) is assumed, under which the KL
    # term reduces to the negative entropy of q(z|x) plus a constant.
    # NOTE: Because the latent dimensions are assumed elementwise independent,
    # summing over them yields the KL for the whole distribution q(z|x).
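    # Derivation, with K = num_classes and q = softmax(logits):
    #   KL(q || Uniform(K)) = sum_k q_k * log(q_k * K) = -H(q) + log K.
    # log K is constant in the parameters, so it is dropped and -H(q) is used
    # directly; softmax_cross_entropy_with_logits_v2 with the distribution as
    # its own label computes exactly the entropy H(q).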
    KL = -tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=tf.nn.softmax(logits), logits=logits),
        axis=1)
    # ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z)), up to the dropped constant.
    ELBO = tf.reduce_sum(p_x_given_z.log_prob(inputs), axis=1) - KL
    loss = tf.reduce_mean(-ELBO)
    train_op = tf.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        temperature = 1.0
        for i in tqdm(range(1, FLAGS.num_iters)):
            np_x, np_y = data.train.next_batch(FLAGS.batch_size)
            _, np_loss = sess.run([train_op, loss], {
                inputs: np_x,
                tau: temperature
            })
            # Anneal the Gumbel-Softmax temperature every 1000 iterations,
            # clipped below at FLAGS.min_temp.
            if i % 1000 == 0:
                temperature = np.maximum(FLAGS.min_temp,
                                         np.exp(-FLAGS.anneal_rate * i))
                print('Temperature updated to {}\n'.format(temperature))
            # Periodically report the ELBO (shifted by the constant dropped
            # from the KL term).
            if i % 5000 == 1:
                print('Iteration {}\nELBO: {}\n'.format(i, -np_loss))

        # Plot test-set reconstructions alongside samples decoded from the prior.
        x_mean = p_x_given_z.mean()
        batch = data.test.next_batch(FLAGS.batch_size)
        np_x = sess.run(x_mean, {inputs: batch[0], tau: FLAGS.min_temp})

        x_mean_eval = p_x_given_z_eval.mean()
        np_x_eval = sess.run(x_mean_eval)

        plot_squares(batch[0], np_x, np_x_eval, 8)
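
# For reference, a minimal sketch of the gumbel_softmax op called in train(),
# following the construction of Jang et al. (2017); the project's actual
# implementation may differ in details.
def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise via the inverse-CDF transform of Uniform(0, 1).
    u = tf.random_uniform(shape, minval=0.0, maxval=1.0)
    return -tf.log(-tf.log(u + eps) + eps)


def gumbel_softmax(logits, temperature, hard=False):
    # Differentiable, approximately one-hot sample from Categorical(logits).
    y = tf.nn.softmax((logits + sample_gumbel(tf.shape(logits))) / temperature)
    if hard:
        # Straight-through variant: one-hot on the forward pass, gradients
        # taken through the soft sample y.
        y_hard = tf.cast(
            tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)), y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
    return y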