Example n. 1
0
 def _gauss_log_pi(self, mu, log_sig):
     """Sample an action from a squashed Gaussian policy.

     Builds a Normal(mu, exp(log_sig)), draws a raw sample, squashes it
     into an action, and returns the squash-corrected log-density with a
     trailing singleton axis, together with the squashed actions.
     """
     dist = Normal(mu, tf.exp(log_sig))
     raw_sample = dist.sample()
     # Log-density of the raw sample, corrected for the squashing bijection.
     corrected_log_pi = dist.log_prob(raw_sample) - self._squash_correction(raw_sample)
     squashed = self._squash_actions(raw_sample)
     return corrected_log_pi[:, None], squashed
    def one_step_IAF(self, a, x):
        """One latent transition step with an IAF flow applied to q.

        ``a`` carries the previous scan state (previous latent first);
        ``x`` is the (control, encoding) pair for this step.  Returns the
        flowed latent together with its log q and log p densities.
        """
        prev_z = a[0]
        control, encoding = x

        # The IAF context network sees the previous latent, the frame
        # encoding and the control signal.
        context = self.q_henc(tf.concat([prev_z, encoding, control], 1))

        q_mu, q_variance = self.q_transition(prev_z, encoding, control)
        p_mu, p_variance = self.p_transition(prev_z, control)
        q_dist = MultivariateNormalDiag(q_mu, tf.sqrt(q_variance))
        p_dist = MultivariateNormalDiag(p_mu, tf.sqrt(p_variance))

        base_sample = q_dist.sample()
        # Density of the base sample, before the flow is applied.
        log_q = q_dist.log_prob(base_sample)

        flowed_z, flow_scale = self.q_transition_IAF(base_sample, context)
        # Change-of-variables correction for the autoregressive flow.
        log_q = log_q - tf.reduce_sum(tf.log(flow_scale + 1e-5), axis=1)

        # TODO: check whether p should score base_sample or the flowed sample.
        log_p = p_dist.log_prob(flowed_z)

        return flowed_z, log_q, log_p
Example n. 3
0
    def one_step_IAF(self, a, x):
        """Advance the latent one step through the IAF transition network.

        ``a`` is the previous scan state (latent, accumulated log q, ...);
        ``x`` is the (control, encoding) input for this step.  Returns the
        next latent, the flow-corrected log q, and log p under the prior
        transition.
        """
        prev_z, prev_log_q = a[0], a[1]
        control, encoding = x

        next_z, flow_scale = self.q_transition_IAF(prev_z, encoding, control)

        # Prior transition p(z'|z, u) as a diagonal Gaussian.
        prior_mu, prior_variance = self.p_transition(prev_z, control)
        prior = MultivariateNormalDiag(prior_mu, tf.sqrt(prior_variance))

        # Change-of-variables correction applied to the carried log q.
        log_q = prev_log_q - tf.reduce_sum(tf.log(flow_scale + 1e-5), axis=1)
        log_p = prior.log_prob(next_z)

        return next_z, log_q, log_p
Example n. 4
0
    def gmm_log_pi(self, log_weights, mu, log_std):
        """Sample an action from a Gaussian-mixture policy and score it.

        A component is drawn from the (log-)weights, a sample is taken from
        the selected Gaussian, and the mixture log-density of that sample is
        computed with a squash correction and temperature scaling.  Returns
        (log_pi[:, None], squashed_action).
        """
        component_dist = Normal(mu, tf.exp(log_std))

        # --- sample from the mixture (no gradient through sampling) ---
        chosen = tf.stop_gradient(
            tf.multinomial(logits=log_weights, num_samples=1))
        per_component_samples = tf.stop_gradient(component_dist.sample())
        selector = tf.one_hot(chosen[:, 0], depth=self._actor.K)
        z = tf.reduce_sum(per_component_samples * selector[:, :, None], axis=1)
        action = self.squash_action(z)

        # --- mixture log-density of the selected sample ---
        component_log_pi = component_dist.log_prob(z[:, None, :])
        log_pi = (tf.reduce_logsumexp(component_log_pi + log_weights, axis=-1)
                  - tf.reduce_logsumexp(log_weights, axis=-1))
        # Tanh-squash correction and temperature scaling.
        log_pi = (log_pi - self.get_squash_correction(z)) * self._temp

        return log_pi[:, None], action
Example n. 5
0
    def network(self, inputs, pi_raw_action, q_action, phase, num_samples):
        """Build the shared actor-critic graph.

        Args:
            inputs: observation batch fed into the shared trunk.
            pi_raw_action: pre-tanh action whose log-probability is evaluated
                under the policy distribution.
            q_action: action evaluated by the Q branch.
            phase: flag forwarded to apply_norm (normalization phase).
            num_samples: number of policy samples drawn per state.

        Returns:
            (pi_alpha, pi_mu, pi_std, pi_raw_samples, pi_samples,
            pi_samples_logprob, pi_actions_logprob, q_samples_prediction,
            q_actions_prediction).
        """
        # TODO: Remove alpha (not using multimodal)
        # shared net
        shared_net = tf.contrib.layers.fully_connected(
            inputs,
            self.shared_layer_dim,
            activation_fn=None,
            weights_initializer=tf.contrib.layers.variance_scaling_initializer(
                factor=1.0, mode="FAN_IN", uniform=True),
            weights_regularizer=tf.contrib.layers.l2_regularizer(0.01),
            biases_initializer=tf.contrib.layers.variance_scaling_initializer(
                factor=1.0, mode="FAN_IN", uniform=True))

        shared_net = self.apply_norm(shared_net,
                                     activation_fn=tf.nn.relu,
                                     phase=phase,
                                     layer_num=1)

        # action branch
        pi_net = tf.contrib.layers.fully_connected(
            shared_net,
            self.actor_layer_dim,
            activation_fn=None,
            weights_initializer=tf.contrib.layers.variance_scaling_initializer(
                factor=1.0, mode="FAN_IN", uniform=True),
            weights_regularizer=None,
            biases_initializer=tf.contrib.layers.variance_scaling_initializer(
                factor=1.0, mode="FAN_IN", uniform=True))

        pi_net = self.apply_norm(pi_net,
                                 activation_fn=tf.nn.relu,
                                 phase=phase,
                                 layer_num=2)

        # no activation
        pi_mu = tf.contrib.layers.fully_connected(
            pi_net,
            self.num_modal * self.action_dim,
            activation_fn=None,
            weights_initializer=tf.contrib.layers.variance_scaling_initializer(
                factor=1.0, mode="FAN_IN", uniform=True),
            # weights_initializer=tf.random_uniform_initializer(-3e-3, 3e-3),
            weights_regularizer=None,
            # tf.contrib.layers.l2_regularizer(0.001),
            biases_initializer=tf.contrib.layers.variance_scaling_initializer(
                factor=1.0, mode="FAN_IN", uniform=True))
        # biases_initializer=tf.random_uniform_initializer(-3e-3, 3e-3))

        # tanh keeps the raw log-std output bounded in [-1, 1]; it is mapped
        # into [LOG_STD_MIN, LOG_STD_MAX] below.
        pi_logstd = tf.contrib.layers.fully_connected(
            pi_net,
            self.num_modal * self.action_dim,
            activation_fn=tf.tanh,
            weights_initializer=tf.random_uniform_initializer(0, 1),
            weights_regularizer=None,
            # tf.contrib.layers.l2_regularizer(0.001),
            biases_initializer=tf.random_uniform_initializer(-3e-3, 3e-3))

        pi_alpha = tf.contrib.layers.fully_connected(
            pi_net,
            self.num_modal,
            activation_fn=tf.tanh,
            weights_initializer=tf.random_uniform_initializer(-3e-3, 3e-3),
            weights_regularizer=None,
            # tf.contrib.layers.l2_regularizer(0.001),
            biases_initializer=tf.random_uniform_initializer(-3e-3, 3e-3))

        # reshape output
        # Only the unimodal case is supported; the multimodal reshapes below
        # are kept commented out.
        assert (self.num_modal == 1)

        # pi_mu = tf.reshape(pi_mu, [-1, self.num_modal, self.action_dim])
        # pi_logstd = tf.reshape(pi_logstd, [-1, self.num_modal, self.action_dim])
        # pi_alpha = tf.reshape(pi_alpha, [-1, self.num_modal, 1])

        pi_mu = tf.reshape(pi_mu, [-1, self.action_dim])
        pi_logstd = tf.reshape(pi_logstd, [-1, self.action_dim])
        pi_alpha = tf.reshape(pi_alpha, [-1, 1])

        # exponentiate logstd
        # Affine map of pi_logstd from [-1, 1] to [LOG_STD_MIN, LOG_STD_MAX],
        # then exponentiate to get a strictly positive std.
        # pi_std = tf.exp(tf.scalar_mul(self.sigma_scale, pi_logstd))
        pi_std = tf.exp(self.LOG_STD_MIN + 0.5 *
                        (self.LOG_STD_MAX - self.LOG_STD_MIN) *
                        (pi_logstd + 1))

        # construct MultivariateNormalDiag dist.
        mvn = MultivariateNormalDiag(loc=pi_mu, scale_diag=pi_std)

        if self.actor_update == "reparam":
            # Reparameterization trick: sample = mu + noise * std, so the
            # sample is differentiable w.r.t. mu and std.
            # pi = mu + tf.random_normal(tf.shape(mu)) * std
            # logp_pi = self.gaussian_likelihood(pi, mu, log_std)

            # pi_mu: (batch_size, action_dim)

            # (batch_size x num_samples, action_dim)
            # If updating multiple samples
            stacked_pi_mu = tf.expand_dims(pi_mu, 1)
            stacked_pi_mu = tf.tile(stacked_pi_mu, [1, num_samples, 1])
            stacked_pi_mu = tf.reshape(
                stacked_pi_mu,
                (-1,
                 self.action_dim))  # (batch_size * num_samples, action_dim)

            stacked_pi_std = tf.expand_dims(pi_std, 1)
            stacked_pi_std = tf.tile(stacked_pi_std, [1, num_samples, 1])
            stacked_pi_std = tf.reshape(
                stacked_pi_std,
                (-1,
                 self.action_dim))  # (batch_size * num_samples, action_dim)

            noise = tf.random_normal(tf.shape(stacked_pi_mu))

            # (batch_size * num_samples, action_dim)
            pi_raw_samples = stacked_pi_mu + noise * stacked_pi_std
            # NOTE(review): the reshape below to (-1, num_samples, action_dim)
            # implies gaussian_loglikelihood returns per-dimension terms (not a
            # summed scalar per sample) — confirm against its definition.
            pi_raw_samples_logprob = self.gaussian_loglikelihood(
                pi_raw_samples, stacked_pi_mu, stacked_pi_std)

            pi_raw_samples = tf.reshape(pi_raw_samples,
                                        (-1, num_samples, self.action_dim))
            pi_raw_samples_logprob = tf.reshape(
                pi_raw_samples_logprob, (-1, num_samples, self.action_dim))

        else:
            pi_raw_samples_og = mvn.sample(num_samples)

            # dim: (batch_size, num_samples, action_dim)
            pi_raw_samples = tf.transpose(pi_raw_samples_og, [1, 0, 2])

            # get raw logprob
            # NOTE(review): MultivariateNormalDiag.log_prob of a
            # (num_samples, batch, action_dim) sample yields a rank-2
            # (num_samples, batch) tensor, so the rank-3 perm [1, 0, 2] below
            # looks inconsistent — confirm this branch is actually exercised.
            pi_raw_samples_logprob_og = mvn.log_prob(pi_raw_samples_og)
            pi_raw_samples_logprob = tf.transpose(pi_raw_samples_logprob_og,
                                                  [1, 0, 2])

        # apply tanh
        pi_mu = tf.tanh(pi_mu)
        pi_samples = tf.tanh(pi_raw_samples)

        # Tanh-squash correction: subtract sum(log(1 - tanh(x)^2)) with a
        # clipped argument and small epsilon for numerical stability.
        pi_samples_logprob = pi_raw_samples_logprob - tf.reduce_sum(tf.log(
            self.clip_but_pass_gradient(1 - pi_samples**2, l=0, u=1) + 1e-6),
                                                                    axis=-1)

        # Scale from (-1, 1) into the environment's action range.
        pi_mu = tf.multiply(pi_mu, self.action_max)
        pi_samples = tf.multiply(pi_samples, self.action_max)

        # compute logprob for input action
        pi_raw_actions_logprob = mvn.log_prob(pi_raw_action)
        pi_action = tf.tanh(pi_raw_action)
        pi_actions_logprob = pi_raw_actions_logprob - tf.reduce_sum(tf.log(
            self.clip_but_pass_gradient(1 - pi_action**2, l=0, u=1) + 1e-6),
                                                                    axis=-1)

        # TODO: Remove alpha
        # compute softmax prob. of alpha
        # Numerically stable softmax: subtract the row max before exp.
        max_alpha = tf.reduce_max(pi_alpha, axis=1, keepdims=True)
        pi_alpha = tf.subtract(pi_alpha, max_alpha)
        pi_alpha = tf.exp(pi_alpha)

        normalize_alpha = tf.reciprocal(
            tf.reduce_sum(pi_alpha, axis=1, keepdims=True))
        pi_alpha = tf.multiply(normalize_alpha, pi_alpha)

        # Q branch
        # Same variable scope with reuse=True so both Q evaluations share
        # weights.
        with tf.variable_scope('qf'):
            q_actions_prediction = self.q_network(shared_net, q_action, phase)
        with tf.variable_scope('qf', reuse=True):
            # if len(tf.shape(pi_samples)) == 3:
            # NOTE(review): this reshape assumes the leading dim of
            # pi_samples equals self.batch_size — verify for evaluation-time
            # batches of a different size.
            pi_samples_reshaped = tf.reshape(
                pi_samples, (self.batch_size * num_samples, self.action_dim))
            # else:
            #     assert(len(tf.shape(pi_samples)) == 2)
            #     pi_samples_reshaped = pi_samples
            q_samples_prediction = self.q_network(shared_net,
                                                  pi_samples_reshaped, phase)

        # print(pi_raw_action, pi_action)
        # print(pi_raw_actions_logprob, pi_raw_actions_logprob)
        # print(pi_action, pi_actions_logprob)

        return pi_alpha, pi_mu, pi_std, pi_raw_samples, pi_samples, pi_samples_logprob, pi_actions_logprob, q_samples_prediction, q_actions_prediction
    def __init__(self, n_obs, n_control, n_latent, n_enc, chkpoint_file=None):
        """Build the deep variational state-space model graph.

        Args:
            n_obs: dimensionality of each observation x_t.
            n_control: dimensionality of each control input u_t.
            n_latent: dimensionality of the latent state z_t.
            n_enc: dimensionality of the learned observation encoding.
            chkpoint_file: optional checkpoint file to restore weights from.
        """
        self.learning_rate = tf.placeholder(tf.float32)
        self.annealing_rate = tf.placeholder(tf.float32)

        # Dimensions
        self.n_output = n_obs
        self.n_obs = n_obs
        self.n_control = n_control
        self.n_latent = n_latent
        self.n_enc = n_enc

        # Input placeholders.
        # NOTE(review): the leading axis is indexed as time below
        # (self.x[0], self.u[:-1]) — confirm callers feed time-major data.
        self.x = tf.placeholder(tf.float32, [None, None, self.n_obs], name="X")
        self.u = tf.placeholder(tf.float32, [None, None, self.n_control], name="U")

        # Initialize p(z0), p(x|z), q(z'|enc, u, z) and p(z'|z) as well as the
        # mlp that generates a low dimensional encoding of x, called enc.
        self._init_generative_dist()
        self._init_start_dist()
        self._init_encoding_mlp()
        self.transition = BaselineTransitionNoKL(self.n_latent, self.n_enc, self.n_control)

        # Encoded representation of the observations (this makes sense when
        # observations are high-dimensional, e.g. images).
        enc = self.get_enc_rep(self.x)

        # Latent start state: z0 ~ q0(.|x0) with a standard-normal prior.
        q0 = self.get_start_dist(self.x[0])
        z0 = q0.sample()
        log_q0 = q0.log_prob(z0)
        p0 = MultivariateNormalDiag(tf.zeros(tf.shape(z0)), tf.ones(tf.shape(z0)))
        log_p0 = p0.log_prob(z0)

        # Trajectory rollout in latent space; the scan also accumulates the
        # per-step log q and log p terms for KL(q(z'|enc, u, z) || p(z'|z)).
        z, log_q, log_p = tf.scan(self.transition.one_step_IAF, (self.u[:-1], enc[1:]), (z0, log_q0, log_p0))
        self.z = tf.concat([[z0], z], 0)

        # Generative distribution p(x|z) + reconstruction error over the whole
        # sequence, including x[0].
        px = self.get_generative_dist(self.z)
        rec_loss = -px.log_prob(self.x)

        self.px_mean = px.mean()

        # Generating trajectories given only an initial observation.
        gen_z = tf.scan(self.transition.gen_one_step, self.u[:-1], z0)
        self.gen_z = tf.concat([[z0], gen_z], 0)
        gen_px = self.get_generative_dist(self.gen_z)
        self.gen_x_mean = gen_px.mean()

        # Prepend the start-state terms so the KL covers the full sequence.
        log_p = tf.concat([[log_p0], log_p], 0)
        log_q = tf.concat([[log_q0], log_q], 0)

        # Losses: annealed KL term plus mean reconstruction error.
        self.log_p = log_p * self.annealing_rate
        self.log_q = log_q * self.annealing_rate
        self.rec_loss = tf.reduce_mean(rec_loss)
        self.kl_loss = tf.reduce_mean(self.log_q - self.log_p)
        # BUG FIX: was `self.total_loss = self.total_loss = ...`
        # (a duplicated assignment target).
        self.total_loss = self.kl_loss + self.rec_loss

        # Adam with element-wise gradient clipping; None gradients (variables
        # not on the loss path) are passed through untouched.
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
        grads_and_vars = self.optimizer.compute_gradients(self.total_loss)
        capped_grads_and_vars = [
            (tf.clip_by_value(grad, -1, 1), var) if grad is not None else (grad, var)
            for (grad, var) in grads_and_vars
        ]
        # NOTE: self.optimizer is rebound from the optimizer object to the
        # training op, preserving the original public attribute behavior.
        self.optimizer = self.optimizer.apply_gradients(capped_grads_and_vars)

        # Save all global variables; keep every checkpoint.
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        # Launch the session and optionally restore weights.
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
        if chkpoint_file:
            utils.load_checkpoint(self.sess, chkpoint_file)