Example No. 1
    def setup_actor_optimizer(self):
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info('setting up actor optimizer')

        # Penalize the actor's preactivations, as in Hindsight Experience Replay, to prevent tanh saturation.
        if self.actor_reg:
            preactivation = tf.get_default_graph().get_tensor_by_name(
                'actor/preactivation:0')
            self.actor_loss = -tf.reduce_mean(
                self.critic_with_actor_tf) + tf.reduce_mean(
                    tf.square(preactivation))

        else:
            self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)

        actor_shapes = [
            var.get_shape().as_list() for var in self.actor.trainable_vars
        ]
        actor_nb_params = sum(
            [reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
        logger.info('  actor shapes: {}'.format(actor_shapes))
        logger.info('  actor params: {}'.format(actor_nb_params))

        self.actor_grads = U.flatgrad(self.actor_loss,
                                      self.actor.trainable_vars,
                                      clip_norm=self.clip_norm)
        self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-08)
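
The actor setup above delegates gradient flattening to `U.flatgrad` and the synchronized update to `MpiAdam`. As a rough illustration of what a flatgrad-style helper does, here is a minimal sketch assuming TF1-style graph code; the exact clipping convention and None-gradient handling of the `U.flatgrad` used in these examples are assumptions.

import tensorflow as tf

def flatgrad_sketch(loss, var_list, clip_norm=None):
    # Per-variable gradients of the scalar loss.
    grads = tf.gradients(loss, var_list)
    flat = []
    for v, g in zip(var_list, grads):
        if g is None:
            # Variables that do not influence the loss contribute zeros.
            g = tf.zeros_like(v)
        elif clip_norm is not None:
            # One possible convention: clip each gradient by its own norm.
            g = tf.clip_by_norm(g, clip_norm)
        flat.append(tf.reshape(g, [-1]))
    # A single flat vector is what an MPI-synchronized Adam update consumes.
    return tf.concat(flat, axis=0)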
Example No. 2
    def setup_critic_optimizer(self):
        logger.info('setting up critic optimizer')
        normalized_critic_target_tf = tf.clip_by_value(
            normalize(self.critic_target, self.ret_rms), self.return_range[0],
            self.return_range[1])
        self.critic_loss = tf.reduce_mean(
            tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
        if self.critic_l2_reg > 0.:
            critic_reg_vars = [
                var for var in self.critic.trainable_vars
                if 'kernel' in var.name and 'output' not in var.name
            ]
            for var in critic_reg_vars:
                logger.info('  regularizing: {}'.format(var.name))
            logger.info('  applying l2 regularization with {}'.format(
                self.critic_l2_reg))
            critic_reg = tc.layers.apply_regularization(
                tc.layers.l2_regularizer(self.critic_l2_reg),
                weights_list=critic_reg_vars)
            self.critic_loss += critic_reg
        critic_shapes = [
            var.get_shape().as_list() for var in self.critic.trainable_vars
        ]
        critic_nb_params = sum(
            [reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
        logger.info('  critic shapes: {}'.format(critic_shapes))
        logger.info('  critic params: {}'.format(critic_nb_params))
        self.critic_grads = U.flatgrad(self.critic_loss,
                                       self.critic.trainable_vars,
                                       clip_norm=self.clip_norm)
        self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
                                        beta1=0.9,
                                        beta2=0.999,
                                        epsilon=1e-08,
                                        single_train=self.single_train)
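
Both setup methods only build tensors; nothing is optimized until the flat gradient is evaluated and handed to `MpiAdam`, as the `train()` methods further down do. A compressed sketch of that step, with `sess`, `agent`, `batch`, and `target_Q` standing in for objects defined elsewhere:

# Evaluate the critic loss and its flat gradient for one sampled batch.
critic_grads, critic_loss = sess.run(
    [agent.critic_grads, agent.critic_loss],
    feed_dict={
        agent.obs0: batch['obs0'],
        agent.actions: batch['actions'],
        agent.critic_target: target_Q,  # computed beforehand from the target networks
    })
# MpiAdam averages the flat gradient across workers, then applies an Adam step.
agent.critic_optimizer.update(critic_grads, stepsize=agent.critic_lr)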
Example No. 3
class CDQ(object):
    def __init__(self,
                 actor,
                 critic,
                 additional_critic,
                 memory,
                 observation_shape,
                 action_shape,
                 param_noise=None,
                 action_noise=None,
                 gamma=0.99,
                 tau=0.001,
                 normalize_returns=False,
                 enable_popart=False,
                 normalize_observations=True,
                 batch_size=128,
                 observation_range=(-5., 5.),
                 action_range=(-1., 1.),
                 return_range=(-np.inf, np.inf),
                 adaptive_param_noise=True,
                 adaptive_param_noise_policy_threshold=.1,
                 critic_l2_reg=0.,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 clip_norm=None,
                 reward_scale=1.,
                 actor_reg=True,
                 select_action=False,
                 skillset=None):

        logger.debug("Parameterized DDPG params")
        logger.debug(str(locals()))
        logger.debug("-" * 20)
        # Inputs.
        self.obs0 = tf.placeholder(tf.float32,
                                   shape=(None, ) + observation_shape,
                                   name='obs0')
        self.obs1 = tf.placeholder(tf.float32,
                                   shape=(None, ) + observation_shape,
                                   name='obs1')
        self.terminals1 = tf.placeholder(tf.float32,
                                         shape=(None, 1),
                                         name='terminals1')
        self.rewards = tf.placeholder(tf.float32,
                                      shape=(None, 1),
                                      name='rewards')
        self.actions = tf.placeholder(tf.float32,
                                      shape=(None, ) + action_shape,
                                      name='actions')
        self.critic_target = tf.placeholder(tf.float32,
                                            shape=(None, 1),
                                            name='critic_target')
        self.param_noise_stddev = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='param_noise_stddev')

        if select_action:
            self.temperature = tf.placeholder(tf.float32,
                                              shape=(),
                                              name='temperature')

        # Parameters.
        self.gamma = gamma
        self.tau = tau
        self.memory = memory
        self.normalize_observations = normalize_observations
        self.normalize_returns = normalize_returns
        self.action_noise = action_noise
        self.param_noise = param_noise
        self.action_range = action_range
        self.return_range = return_range
        self.observation_range = observation_range
        self.critic = critic
        self.critic1 = additional_critic
        self.actor = actor
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.clip_norm = clip_norm
        self.enable_popart = enable_popart
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.stats_sample = None
        self.critic_l2_reg = critic_l2_reg
        self.select_action = select_action
        self.actor_reg = actor_reg

        # Observation normalization.
        if self.normalize_observations:
            with tf.variable_scope('obs_rms'):
                self.obs_rms = RunningMeanStd(shape=observation_shape)
        else:
            self.obs_rms = None
        normalized_obs0 = tf.clip_by_value(normalize(self.obs0, self.obs_rms),
                                           self.observation_range[0],
                                           self.observation_range[1])
        normalized_obs1 = tf.clip_by_value(normalize(self.obs1, self.obs_rms),
                                           self.observation_range[0],
                                           self.observation_range[1])

        # Return normalization.
        if self.normalize_returns:
            with tf.variable_scope('ret_rms'):
                self.ret_rms = RunningMeanStd()
        else:
            self.ret_rms = None

        # action selection constant
        if self.select_action:
            self.skillset = skillset
            W_select = np.zeros(
                (skillset.len, skillset.num_params + skillset.len))
            W_select[:skillset.len, :skillset.len] = 1.
            for i in range(skillset.len):
                starting_idx = skillset.params_start_idx[i] + skillset.len
                ending_idx = starting_idx + skillset.skillset[i].num_params
                W_select[i, starting_idx:ending_idx] = 1.

            print("Selection matrix:%r" % W_select)
            self.W_select = tf.constant(W_select,
                                        dtype=tf.float32,
                                        name="selection_mat")

        # Create target networks.
        target_actor = copy(actor)
        target_actor.name = 'target_actor'
        self.target_actor = target_actor
        target_critic = copy(critic)
        target_critic.name = 'target_critic'
        self.target_critic = target_critic

        target_critic1 = copy(additional_critic)
        target_critic1.name = 'target_critic1'
        self.target_critic1 = target_critic1

        # Create networks and core TF parts that are shared across setup parts.
        target_actor_prediction_next_state_tf = target_actor(normalized_obs1)

        # critic
        self.normalized_critic_tf = critic(normalized_obs0, self.actions)
        self.critic_tf = denormalize(
            tf.clip_by_value(self.normalized_critic_tf, self.return_range[0],
                             self.return_range[1]), self.ret_rms)
        # additional_critic
        self.normalized_critic_tf1 = additional_critic(normalized_obs0,
                                                       self.actions)
        self.critic_tf1 = denormalize(
            tf.clip_by_value(self.normalized_critic_tf1, self.return_range[0],
                             self.return_range[1]), self.ret_rms)

        if self.select_action:
            Q_obs1_0 = denormalize(
                target_critic(
                    normalized_obs1,
                    choose_actions(target_actor_prediction_next_state_tf,
                                   skillset, self.W_select)), self.ret_rms)
            Q_obs1_1 = denormalize(
                target_critic1(
                    normalized_obs1,
                    choose_actions(target_actor_prediction_next_state_tf,
                                   skillset, self.W_select)), self.ret_rms)

        else:
            Q_obs1_0 = denormalize(
                target_critic(normalized_obs1,
                              target_actor_prediction_next_state_tf),
                self.ret_rms)
            Q_obs1_1 = denormalize(
                target_critic1(normalized_obs1,
                               target_actor_prediction_next_state_tf),
                self.ret_rms)

        Q_obs1 = tf.minimum(Q_obs1_0, Q_obs1_1)
        self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1

        # clip the target Q value
        self.target_Q = tf.clip_by_value(self.target_Q, -1 / (1 - gamma), 0)

        if self.select_action:
            self.actor_with_all_params_tf = actor(obs=normalized_obs0,
                                                  temperature=self.temperature)

            # create np and then convert to tf.constant
            # _, selection_mask = choose_actions(self.actor_with_all_params_tf, skillset, self.W_select, True)
            # actor_tf_clone_with_chosen_action = grad_manipulation_op.py_func(grad_manipulation_op.my_identity_func, [self.actor_with_all_params_tf, selection_mask], self.actor_with_all_params_tf.dtype, name="MyIdentity", grad=grad_manipulation_op._custom_identity_grad)

            # In the backward pass, the discrete selection obtained in the forward run is used as-is.
            actor_tf_clone_with_chosen_action = choose_actions(
                self.actor_with_all_params_tf, skillset, self.W_select, False)
            self.actor_tf = tf.reshape(actor_tf_clone_with_chosen_action,
                                       tf.shape(self.actor_with_all_params_tf))
        else:
            self.actor_tf = actor(normalized_obs0)

        self.normalized_critic_with_actor_tf = critic(normalized_obs0,
                                                      self.actor_tf,
                                                      reuse=True)
        self.critic_with_actor_tf = denormalize(
            tf.clip_by_value(self.normalized_critic_with_actor_tf,
                             self.return_range[0], self.return_range[1]),
            self.ret_rms)

        # Set up parts.
        if self.param_noise is not None:
            self.setup_param_noise(normalized_obs0)
        self.setup_actor_optimizer()
        self.setup_critic_optimizer()
        if self.normalize_returns and self.enable_popart:
            self.setup_popart()
        self.setup_stats()
        self.setup_target_network_updates()

    def setup_target_network_updates(self):
        actor_init_updates, actor_soft_updates = get_target_updates(
            self.actor.vars, self.target_actor.vars, self.tau)
        critic_init_updates, critic_soft_updates = get_target_updates(
            self.critic.vars, self.target_critic.vars, self.tau)
        self.target_init_updates = [actor_init_updates, critic_init_updates]
        self.target_soft_updates = [actor_soft_updates, critic_soft_updates]

    def setup_param_noise(self, normalized_obs0):
        assert self.param_noise is not None

        # Configure perturbed actor.
        param_noise_actor = copy(self.actor)
        param_noise_actor.name = 'param_noise_actor'
        self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
        logger.debug('setting up param noise')
        self.perturb_policy_ops = get_perturbed_actor_updates(
            self.actor, param_noise_actor, self.param_noise_stddev)

        # Configure a separate copy for stddev adaptation.
        adaptive_param_noise_actor = copy(self.actor)
        adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
        adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
        self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(
            self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
        self.adaptive_policy_distance = tf.sqrt(
            tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))

    def setup_actor_optimizer(self):
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info('setting up actor optimizer')

        # Penalize the actor's preactivations, as in Hindsight Experience Replay, to prevent tanh saturation.
        if self.actor_reg:
            preactivation = tf.get_default_graph().get_tensor_by_name(
                'actor/preactivation:0')
            self.actor_loss = -tf.reduce_mean(
                self.critic_with_actor_tf) + tf.reduce_mean(
                    tf.square(preactivation))

        else:
            self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)

        actor_shapes = [
            var.get_shape().as_list() for var in self.actor.trainable_vars
        ]
        actor_nb_params = sum(
            [reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
        logger.info('  actor shapes: {}'.format(actor_shapes))
        logger.info('  actor params: {}'.format(actor_nb_params))

        self.actor_grads = U.flatgrad(self.actor_loss,
                                      self.actor.trainable_vars,
                                      clip_norm=self.clip_norm)
        self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-08)

    def setup_critic_optimizer(self):
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info('setting up critic optimizer')

        normalized_critic_target_tf = tf.clip_by_value(
            normalize(self.critic_target, self.ret_rms), self.return_range[0],
            self.return_range[1])
        self.critic_loss = tf.reduce_mean(tf.square(self.normalized_critic_tf - normalized_critic_target_tf)) \
                            + tf.reduce_mean(tf.square(self.normalized_critic_tf1 - normalized_critic_target_tf))

        if self.critic_l2_reg > 0.:
            critic_reg_vars = [
                var for var in (self.critic.trainable_vars +
                                self.critic1.trainable_vars)
                if 'kernel' in var.name and 'output' not in var.name
            ]
            for var in critic_reg_vars:
                logger.debug('  regularizing: {}'.format(var.name))
            logger.info('  applying l2 regularization with {}'.format(
                self.critic_l2_reg))
            critic_reg = tc.layers.apply_regularization(
                tc.layers.l2_regularizer(self.critic_l2_reg),
                weights_list=critic_reg_vars)
            self.critic_loss += critic_reg
        critic_shapes = [
            var.get_shape().as_list() for var in (self.critic.trainable_vars +
                                                  self.critic1.trainable_vars)
        ]
        critic_nb_params = sum(
            [reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
        logger.info('  critic shapes: {}'.format(critic_shapes))
        logger.info('  critic params: {}'.format(critic_nb_params))
        self.critic_grads = U.flatgrad(
            self.critic_loss,
            (self.critic.trainable_vars + self.critic1.trainable_vars),
            clip_norm=self.clip_norm)
        self.critic_optimizer = MpiAdam(var_list=(self.critic.trainable_vars +
                                                  self.critic1.trainable_vars),
                                        beta1=0.9,
                                        beta2=0.999,
                                        epsilon=1e-08)

    def setup_popart(self):
        # See https://arxiv.org/pdf/1602.07714.pdf for details.
        self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
        new_std = self.ret_rms.std
        self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
        new_mean = self.ret_rms.mean

        self.renormalize_Q_outputs_op = []
        for vs in [self.critic.output_vars, self.target_critic.output_vars]:
            assert len(vs) == 2
            M, b = vs
            assert 'kernel' in M.name
            assert 'bias' in b.name
            assert M.get_shape()[-1] == 1
            assert b.get_shape()[-1] == 1
            self.renormalize_Q_outputs_op += [
                M.assign(M * self.old_std / new_std)
            ]
            self.renormalize_Q_outputs_op += [
                b.assign(
                    (b * self.old_std + self.old_mean - new_mean) / new_std)
            ]

    def setup_stats(self):
        ops = []
        names = []

        if self.normalize_returns:
            ops += [self.ret_rms.mean, self.ret_rms.std]
            names += ['ret_rms_mean', 'ret_rms_std']

        if self.normalize_observations:
            ops += [
                tf.reduce_mean(self.obs_rms.mean),
                tf.reduce_mean(self.obs_rms.std)
            ]
            names += ['obs_rms_mean', 'obs_rms_std']

        ops += [tf.reduce_mean(self.critic_tf)]
        names += ['reference_Q_mean']
        ops += [reduce_std(self.critic_tf)]
        names += ['reference_Q_std']

        ops += [tf.reduce_mean(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_mean']
        ops += [reduce_std(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_std']

        ops += [tf.reduce_mean(self.actor_tf)]
        names += ['reference_action_mean']
        ops += [reduce_std(self.actor_tf)]
        names += ['reference_action_std']

        if self.param_noise:
            ops += [tf.reduce_mean(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_mean']
            ops += [reduce_std(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_std']

        self.stats_ops = ops
        self.stats_names = names

    def pi(self, obs, apply_noise=True, compute_Q=True, temperature=1.):
        if self.param_noise is not None and apply_noise:
            actor_tf = self.perturbed_actor_tf
        elif self.action_noise is not None and apply_noise and self.select_action:
            actor_tf = self.actor_with_all_params_tf
        else:
            actor_tf = self.actor_tf
        feed_dict = {self.obs0: [obs]}
        if self.select_action:
            feed_dict[self.temperature] = temperature

        if compute_Q:
            action, q = self.sess.run([actor_tf, self.critic_with_actor_tf],
                                      feed_dict=feed_dict)
        else:
            action = self.sess.run(actor_tf, feed_dict=feed_dict)
            q = None
        action = action.flatten()

        if self.action_noise is not None and apply_noise:
            action = self.action_noise(action)

        if self.select_action:
            # zero out the params of not taken actions
            chosen_action = np.argmax(action[:self.skillset.len])

            continuous_actions = np.zeros_like(action[self.skillset.len:])
            starting_idx = self.skillset.params_start_idx[chosen_action]
            ending_idx = starting_idx + self.skillset.skillset[
                chosen_action].num_params
            continuous_actions[starting_idx:ending_idx] = action[
                self.skillset.len + starting_idx:self.skillset.len +
                ending_idx]

            action[self.skillset.len:] = continuous_actions.copy()

        action = np.clip(action, self.action_range[0], self.action_range[1])

        return action, q

    def store_transition(self, obs0, action, reward, obs1, terminal1):
        reward *= self.reward_scale
        self.memory.append(obs0, action, reward, obs1, terminal1)
        if self.normalize_observations:
            self.obs_rms.update(np.array([obs0]))

    def train(self, summary_var, temperature=1., update_actor=True):
        # Get a batch.
        batch = self.memory.sample(batch_size=self.batch_size)

        if self.normalize_returns and self.enable_popart:
            old_mean, old_std, target_Q = self.sess.run(
                [self.ret_rms.mean, self.ret_rms.std, self.target_Q],
                feed_dict={
                    self.obs1: batch['obs1'],
                    self.rewards: batch['rewards'],
                    self.terminals1: batch['terminals1'].astype('float32'),
                })
            self.ret_rms.update(target_Q.flatten())
            self.sess.run(self.renormalize_Q_outputs_op,
                          feed_dict={
                              self.old_std: np.array([old_std]),
                              self.old_mean: np.array([old_mean]),
                          })

            # Run sanity check. Disabled by default since it slows things down considerably.
            # print('running sanity check')
            # target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
            #     self.obs1: batch['obs1'],
            #     self.rewards: batch['rewards'],
            #     self.terminals1: batch['terminals1'].astype('float32'),
            # })
            # print(target_Q_new, target_Q, new_mean, new_std)
            # assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
        else:
            target_Q = self.sess.run(self.target_Q,
                                     feed_dict={
                                         self.obs1:
                                         batch['obs1'],
                                         self.rewards:
                                         batch['rewards'],
                                         self.terminals1:
                                         batch['terminals1'].astype('float32'),
                                     })

        # Get all gradients and perform a synced update.
        ops = [
            summary_var, self.actor_grads, self.actor_loss, self.critic_grads,
            self.critic_loss, self.target_Q
        ]
        feed_dict = {
            self.obs0: batch['obs0'],
            self.actions: batch['actions'],
            self.critic_target: target_Q,
            ## just for summary_var
            self.obs1: batch['obs1'],
            self.rewards: batch['rewards'],
            self.terminals1: batch['terminals1'].astype('float32'),
        }

        if self.select_action:
            feed_dict.update({self.temperature: temperature})

        current_summary, actor_grads, actor_loss, critic_grads, critic_loss, _ = self.sess.run(
            ops, feed_dict=feed_dict)

        if update_actor:
            self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
        self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)

        return critic_loss, actor_loss, current_summary

    def initialize(self, sess):
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        self.actor_optimizer.sync()
        self.critic_optimizer.sync()
        self.sess.run(self.target_init_updates)

    def update_target_net(self):
        self.sess.run(self.target_soft_updates)

    def get_stats(self, temperature=1.0):

        if self.stats_sample is None:
            # Get a sample and keep that fixed for all further computations.
            # This allows us to estimate the change in value for the same set of inputs.
            self.stats_sample = self.memory.sample(batch_size=self.batch_size)

        feed_dict = {
            self.obs0: self.stats_sample['obs0'],
            self.actions: self.stats_sample['actions'],
        }

        if self.select_action:
            feed_dict[self.temperature] = temperature

        values = self.sess.run(self.stats_ops, feed_dict=feed_dict)

        names = self.stats_names[:]
        assert len(names) == len(values)
        stats = dict(zip(names, values))

        if self.param_noise is not None:
            stats = {**stats, **self.param_noise.get_stats()}

        return stats

    def adapt_param_noise(self):
        if self.param_noise is None:
            return 0.

        # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
        batch = self.memory.sample(batch_size=self.batch_size)
        self.sess.run(self.perturb_adaptive_policy_ops,
                      feed_dict={
                          self.param_noise_stddev:
                          self.param_noise.current_stddev,
                      })
        distance = self.sess.run(self.adaptive_policy_distance,
                                 feed_dict={
                                     self.obs0:
                                     batch['obs0'],
                                     self.param_noise_stddev:
                                     self.param_noise.current_stddev,
                                 })

        mean_distance = mpi_mean(distance)
        self.param_noise.adapt(mean_distance)
        return mean_distance

    def reset(self):
        # Reset internal state after an episode is complete.
        if self.action_noise is not None:
            self.action_noise.reset()
        if self.param_noise is not None:
            self.sess.run(self.perturb_policy_ops,
                          feed_dict={
                              self.param_noise_stddev:
                              self.param_noise.current_stddev,
                          })
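
Putting the pieces together, a rough usage sketch of the CDQ agent follows. The `actor`, `critic`, `additional_critic`, `memory`, `env`, `summary_var`, and `total_steps` objects are placeholders assumed to match the interfaces used above, not definitions from this code.

agent = CDQ(actor, critic, additional_critic, memory,
            observation_shape=env.observation_space.shape,
            action_shape=env.action_space.shape,
            gamma=0.99, tau=0.001, batch_size=128)

with tf.Session() as sess:
    agent.initialize(sess)
    obs = env.reset()
    for step in range(total_steps):
        # Act with exploration noise, store the transition, then learn.
        action, q = agent.pi(obs, apply_noise=True, compute_Q=True)
        new_obs, reward, done, _ = env.step(action)
        agent.store_transition(obs, action, reward, new_obs, done)
        obs = env.reset() if done else new_obs

        critic_loss, actor_loss, summary = agent.train(summary_var)
        agent.update_target_net()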
Example No. 4
class DDPG(object):
    def __init__(self,
                 actor,
                 critic,
                 memory,
                 observation_shape,
                 action_shape,
                 single_train,
                 param_noise=None,
                 action_noise=None,
                 gamma=0.99,
                 tau=0.001,
                 normalize_returns=False,
                 enable_popart=False,
                 normalize_observations=True,
                 batch_size=128,
                 observation_range=(-5., 5.),
                 action_range=(-1., 1.),
                 return_range=(-np.inf, np.inf),
                 adaptive_param_noise=True,
                 adaptive_param_noise_policy_threshold=.1,
                 critic_l2_reg=0.,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 clip_norm=None,
                 reward_scale=1.,
                 inverting_grad=False,
                 actor_reg=True):

        logger.info("DDPG params")
        logger.info(str(locals()))
        logger.info("-" * 20)
        # Inputs.

        self.single_train = single_train
        # Is the observation space a Tuple space?
        self.tuple_obs = (isinstance(observation_shape[0], list)
                          or isinstance(observation_shape[0], tuple))

        if self.tuple_obs:
            self.obs0 = [
                tf.placeholder(tf.float32,
                               shape=(None, ) + observation_shape_,
                               name='obs0')
                for observation_shape_ in observation_shape
            ]
            self.obs1 = [
                tf.placeholder(tf.float32,
                               shape=(None, ) + observation_shape_,
                               name='obs1')
                for observation_shape_ in observation_shape
            ]
        else:
            self.obs0 = tf.placeholder(tf.float32,
                                       shape=(None, ) + observation_shape,
                                       name='obs0')
            self.obs1 = tf.placeholder(tf.float32,
                                       shape=(None, ) + observation_shape,
                                       name='obs1')

        self.terminals1 = tf.placeholder(tf.float32,
                                         shape=(None, 1),
                                         name='terminals1')
        self.rewards = tf.placeholder(tf.float32,
                                      shape=(None, 1),
                                      name='rewards')
        self.actions = tf.placeholder(tf.float32,
                                      shape=(None, ) + action_shape,
                                      name='actions')
        self.critic_target = tf.placeholder(tf.float32,
                                            shape=(None, 1),
                                            name='critic_target')
        self.param_noise_stddev = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='param_noise_stddev')

        # Parameters.
        self.gamma = gamma
        self.tau = tau
        self.memory = memory
        self.normalize_observations = normalize_observations
        self.normalize_returns = normalize_returns
        self.action_noise = action_noise
        self.param_noise = param_noise
        self.action_range = action_range
        self.return_range = return_range
        self.observation_range = observation_range
        self.critic = critic
        self.actor = actor
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.clip_norm = clip_norm
        self.enable_popart = enable_popart
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.stats_sample = None
        self.critic_l2_reg = critic_l2_reg
        self.inverting_grad = inverting_grad
        self.actor_reg = actor_reg

        self.total_recv = 0
        self.buffers = []

        # Observation normalization.
        if self.normalize_observations:
            with tf.variable_scope('obs_rms'):
                obs_shape = observation_shape[
                    0] if self.tuple_obs else observation_shape
                self.obs_rms = RunningMeanStd(shape=obs_shape)
        else:
            self.obs_rms = None

        if self.tuple_obs:  # normalize only the first item
            # Copy the placeholder lists so that replacing the first entry below
            # does not also overwrite the placeholders stored in self.obs0/self.obs1.
            normalized_obs0 = list(self.obs0)
            normalized_obs1 = list(self.obs1)

            normalized_obs0[0] = tf.clip_by_value(
                normalize(self.obs0[0], self.obs_rms),
                self.observation_range[0], self.observation_range[1])
            normalized_obs1[0] = tf.clip_by_value(
                normalize(self.obs1[0], self.obs_rms),
                self.observation_range[0], self.observation_range[1])

        else:
            normalized_obs0 = tf.clip_by_value(
                normalize(self.obs0, self.obs_rms), self.observation_range[0],
                self.observation_range[1])
            normalized_obs1 = tf.clip_by_value(
                normalize(self.obs1, self.obs_rms), self.observation_range[0],
                self.observation_range[1])

        # Return normalization.
        if self.normalize_returns:
            with tf.variable_scope('ret_rms'):
                self.ret_rms = RunningMeanStd()
        else:
            self.ret_rms = None

        # Create target networks.
        target_actor = copy(actor)
        target_actor.name = 'target_actor'
        self.target_actor = target_actor
        target_critic = copy(critic)
        target_critic.name = 'target_critic'
        self.target_critic = target_critic

        # Create networks and core TF parts that are shared across setup parts.
        self.normalized_critic_tf = critic(normalized_obs0, self.actions)
        self.critic_tf = denormalize(
            tf.clip_by_value(self.normalized_critic_tf, self.return_range[0],
                             self.return_range[1]), self.ret_rms)
        Q_obs1 = denormalize(
            target_critic(normalized_obs1, target_actor(normalized_obs1)),
            self.ret_rms)
        self.target_Q = self.rewards + (1. - self.terminals1) * gamma * Q_obs1

        # clip the target Q value
        self.target_Q = tf.clip_by_value(self.target_Q, -1 / (1 - gamma), 0)

        self.actor_tf = actor(normalized_obs0)
        if inverting_grad:
            actor_tf_clone_with_invert_grad = my_op.py_func(
                my_op.my_identity_func, [self.actor_tf, -1., 1.],
                self.actor_tf.dtype,
                name="MyIdentity",
                grad=my_op._custom_identity_grad)
            self.actor_tf = tf.reshape(actor_tf_clone_with_invert_grad,
                                       tf.shape(self.actor_tf))
        self.normalized_critic_with_actor_tf = critic(normalized_obs0,
                                                      self.actor_tf,
                                                      reuse=True)
        self.critic_with_actor_tf = denormalize(
            tf.clip_by_value(self.normalized_critic_with_actor_tf,
                             self.return_range[0], self.return_range[1]),
            self.ret_rms)

        # Set up parts.
        if self.param_noise is not None:
            self.setup_param_noise(normalized_obs0)
        self.setup_actor_optimizer()
        self.setup_critic_optimizer()
        if self.normalize_returns and self.enable_popart:
            self.setup_popart()
        self.setup_stats()
        self.setup_target_network_updates()

    def setup_target_network_updates(self):
        actor_init_updates, actor_soft_updates = get_target_updates(
            self.actor.vars, self.target_actor.vars, self.tau)
        critic_init_updates, critic_soft_updates = get_target_updates(
            self.critic.vars, self.target_critic.vars, self.tau)
        self.target_init_updates = [actor_init_updates, critic_init_updates]
        self.target_soft_updates = [actor_soft_updates, critic_soft_updates]

    def setup_param_noise(self, normalized_obs0):
        assert self.param_noise is not None

        # Configure perturbed actor.
        param_noise_actor = copy(self.actor)
        param_noise_actor.name = 'param_noise_actor'
        self.perturbed_actor_tf = param_noise_actor(normalized_obs0)
        logger.info('setting up param noise')
        self.perturb_policy_ops = get_perturbed_actor_updates(
            self.actor, param_noise_actor, self.param_noise_stddev)

        # Configure a separate copy for stddev adaptation.
        adaptive_param_noise_actor = copy(self.actor)
        adaptive_param_noise_actor.name = 'adaptive_param_noise_actor'
        adaptive_actor_tf = adaptive_param_noise_actor(normalized_obs0)
        self.perturb_adaptive_policy_ops = get_perturbed_actor_updates(
            self.actor, adaptive_param_noise_actor, self.param_noise_stddev)
        self.adaptive_policy_distance = tf.sqrt(
            tf.reduce_mean(tf.square(self.actor_tf - adaptive_actor_tf)))

    def setup_actor_optimizer(self):
        logger.info('setting up actor optimizer')

        # Penalize the actor's preactivations, as in Hindsight Experience Replay, to prevent tanh saturation.
        if self.actor_reg:
            preactivation = tf.get_default_graph().get_tensor_by_name(
                'actor/preactivation:0')
            self.actor_loss = -tf.reduce_mean(
                self.critic_with_actor_tf) + tf.reduce_mean(
                    tf.square(preactivation))

        else:
            self.actor_loss = -tf.reduce_mean(self.critic_with_actor_tf)

        actor_shapes = [
            var.get_shape().as_list() for var in self.actor.trainable_vars
        ]
        actor_nb_params = sum(
            [reduce(lambda x, y: x * y, shape) for shape in actor_shapes])
        logger.info('  actor shapes: {}'.format(actor_shapes))
        logger.info('  actor params: {}'.format(actor_nb_params))

        self.actor_grads = U.flatgrad(self.actor_loss,
                                      self.actor.trainable_vars,
                                      clip_norm=self.clip_norm)
        self.actor_optimizer = MpiAdam(var_list=self.actor.trainable_vars,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-08,
                                       single_train=self.single_train)

    def setup_critic_optimizer(self):
        logger.info('setting up critic optimizer')
        normalized_critic_target_tf = tf.clip_by_value(
            normalize(self.critic_target, self.ret_rms), self.return_range[0],
            self.return_range[1])
        self.critic_loss = tf.reduce_mean(
            tf.square(self.normalized_critic_tf - normalized_critic_target_tf))
        if self.critic_l2_reg > 0.:
            critic_reg_vars = [
                var for var in self.critic.trainable_vars
                if 'kernel' in var.name and 'output' not in var.name
            ]
            for var in critic_reg_vars:
                logger.info('  regularizing: {}'.format(var.name))
            logger.info('  applying l2 regularization with {}'.format(
                self.critic_l2_reg))
            critic_reg = tc.layers.apply_regularization(
                tc.layers.l2_regularizer(self.critic_l2_reg),
                weights_list=critic_reg_vars)
            self.critic_loss += critic_reg
        critic_shapes = [
            var.get_shape().as_list() for var in self.critic.trainable_vars
        ]
        critic_nb_params = sum(
            [reduce(lambda x, y: x * y, shape) for shape in critic_shapes])
        logger.info('  critic shapes: {}'.format(critic_shapes))
        logger.info('  critic params: {}'.format(critic_nb_params))
        self.critic_grads = U.flatgrad(self.critic_loss,
                                       self.critic.trainable_vars,
                                       clip_norm=self.clip_norm)
        self.critic_optimizer = MpiAdam(var_list=self.critic.trainable_vars,
                                        beta1=0.9,
                                        beta2=0.999,
                                        epsilon=1e-08,
                                        single_train=self.single_train)

    def setup_popart(self):
        # See https://arxiv.org/pdf/1602.07714.pdf for details.
        self.old_std = tf.placeholder(tf.float32, shape=[1], name='old_std')
        new_std = self.ret_rms.std
        self.old_mean = tf.placeholder(tf.float32, shape=[1], name='old_mean')
        new_mean = self.ret_rms.mean

        self.renormalize_Q_outputs_op = []
        for vs in [self.critic.output_vars, self.target_critic.output_vars]:
            assert len(vs) == 2
            M, b = vs
            assert 'kernel' in M.name
            assert 'bias' in b.name
            assert M.get_shape()[-1] == 1
            assert b.get_shape()[-1] == 1
            self.renormalize_Q_outputs_op += [
                M.assign(M * self.old_std / new_std)
            ]
            self.renormalize_Q_outputs_op += [
                b.assign(
                    (b * self.old_std + self.old_mean - new_mean) / new_std)
            ]

    def setup_stats(self):
        ops = []
        names = []

        if self.normalize_returns:
            ops += [self.ret_rms.mean, self.ret_rms.std]
            names += ['ret_rms_mean', 'ret_rms_std']

        if self.normalize_observations:
            ops += [
                tf.reduce_mean(self.obs_rms.mean),
                tf.reduce_mean(self.obs_rms.std)
            ]
            names += ['obs_rms_mean', 'obs_rms_std']

        ops += [tf.reduce_mean(self.critic_tf)]
        names += ['reference_Q_mean']
        ops += [reduce_std(self.critic_tf)]
        names += ['reference_Q_std']

        ops += [tf.reduce_mean(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_mean']
        ops += [reduce_std(self.critic_with_actor_tf)]
        names += ['reference_actor_Q_std']

        ops += [tf.reduce_mean(self.actor_tf)]
        names += ['reference_action_mean']
        ops += [reduce_std(self.actor_tf)]
        names += ['reference_action_std']

        if self.param_noise:
            ops += [tf.reduce_mean(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_mean']
            ops += [reduce_std(self.perturbed_actor_tf)]
            names += ['reference_perturbed_action_std']

        self.stats_ops = ops
        self.stats_names = names

    def pi(self, obs, apply_noise=True, compute_Q=True):

        if self.param_noise is not None and apply_noise:
            actor_tf = self.perturbed_actor_tf
        else:
            actor_tf = self.actor_tf

        if self.tuple_obs:
            feed_dict = {ph: [o_] for (ph, o_) in zip(self.obs0, obs)}
        else:
            feed_dict = {self.obs0: [obs]}

        if compute_Q:
            action, q = self.sess.run([actor_tf, self.critic_with_actor_tf],
                                      feed_dict=feed_dict)
        else:
            action = self.sess.run(actor_tf, feed_dict=feed_dict)
            q = None
        action = action.flatten()
        if self.action_noise is not None and apply_noise:
            # noise = self.action_noise()
            # assert noise.shape == action.shape
            # action += noise
            action = self.action_noise(action)
        action = np.clip(action, self.action_range[0], self.action_range[1])
        return action, q

    def recv_transitions(self):
        assert self.single_train
        comm = MPI.COMM_WORLD
        assert comm.Get_rank() == 0
        status = MPI.Status()
        count = 0

        remaining = comm.Get_size() - 1
        while remaining > 0:
            comm.probe(status=status)
            if status.tag == 0:
                transition = comm.recv(source=status.source, tag=0)
                self.store_transition_to_memory(*transition)
                count += 1
            elif status.tag == 1:
                finished = comm.recv(source=status.source, tag=1)
                assert finished == 'finished'
                remaining -= 1
            else:
                assert False

        self.total_recv += count
        #print("Received %d transitions out of %d" % (count, self.total_recv))

    def finish_sending(self):
        assert self.single_train
        comm = MPI.COMM_WORLD
        assert comm.Get_rank() > 0
        # Wait for all outstanding non-blocking sends to complete.

        while self.buffers:
            self.buffers.pop().wait()

        comm.send('finished', dest=0, tag=1)
        #comm.isend('finished', dest=0, tag=1)
        #print("%d sent finished" % comm.Get_rank())

    def send_transition(self, transition):
        assert self.single_train
        comm = MPI.COMM_WORLD
        assert comm.Get_rank() > 0
        buf = comm.isend(transition, dest=0, tag=0)
        self.buffers.append(buf)
        self.total_recv += 1
        #if self.total_recv % 100 == 0:
        #    print("Total sent: %d" % self.total_recv)

    def store_transition(self, obs0, action, reward, obs1, terminal1):
        transition = (obs0, action, reward, obs1, terminal1)
        if self.single_train and MPI.COMM_WORLD.Get_rank() > 0:
            self.send_transition(transition)
        else:
            self.store_transition_to_memory(*transition)

    def store_transition_to_memory(self, obs0, action, reward, obs1,
                                   terminal1):
        reward *= self.reward_scale
        self.memory.append(obs0, action, reward, obs1, terminal1)
        if self.normalize_observations:
            if self.tuple_obs:
                self.obs_rms.update(np.array([obs0[0]]))
            else:
                self.obs_rms.update(np.array([obs0]))

    def obs_feed_dict(self, obs_ph, obs_from_memory):
        if self.tuple_obs:
            return dict(zip(obs_ph, obs_from_memory))
        else:
            return {obs_ph: obs_from_memory}

    def train(self, summary_var):

        if self.single_train:
            assert MPI.COMM_WORLD.Get_rank() == 0

        # Get a batch.
        batch = self.memory.sample(batch_size=self.batch_size)

        if self.normalize_returns and self.enable_popart:

            fd = self.obs_feed_dict(self.obs1, batch['obs1'])
            fd.update({
                self.rewards: batch['rewards'],
                self.terminals1: batch['terminals1'].astype('float32')
            })

            old_mean, old_std, target_Q = self.sess.run(
                [self.ret_rms.mean, self.ret_rms.std, self.target_Q],
                feed_dict=fd)

            self.ret_rms.update(target_Q.flatten())
            self.sess.run(self.renormalize_Q_outputs_op,
                          feed_dict={
                              self.old_std: np.array([old_std]),
                              self.old_mean: np.array([old_mean]),
                          })

            # Run sanity check. Disabled by default since it slows things down considerably.
            # print('running sanity check')
            # target_Q_new, new_mean, new_std = self.sess.run([self.target_Q, self.ret_rms.mean, self.ret_rms.std], feed_dict={
            #     self.obs1: batch['obs1'],
            #     self.rewards: batch['rewards'],
            #     self.terminals1: batch['terminals1'].astype('float32'),
            # })
            # print(target_Q_new, target_Q, new_mean, new_std)
            # assert (np.abs(target_Q - target_Q_new) < 1e-3).all()
        else:

            fd = self.obs_feed_dict(self.obs1, batch['obs1'])
            fd.update({
                self.rewards: batch['rewards'],
                self.terminals1: batch['terminals1'].astype('float32')
            })

            target_Q = self.sess.run(self.target_Q, feed_dict=fd)

        # Get all gradients and perform a synced update.
        ops = [
            summary_var, self.actor_grads, self.actor_loss, self.critic_grads,
            self.critic_loss, self.target_Q
        ]

        fd = self.obs_feed_dict(self.obs0, batch['obs0'])
        fd.update(self.obs_feed_dict(self.obs1, batch['obs1']))
        fd.update({
            self.actions: batch['actions'],
            self.critic_target: target_Q,
            ## just for summary_var
            self.rewards: batch['rewards'],
            self.terminals1: batch['terminals1'].astype('float32'),
        })

        current_summary, actor_grads, actor_loss, critic_grads, critic_loss, _ = self.sess.run(
            ops, feed_dict=fd)

        self.actor_optimizer.update(actor_grads, stepsize=self.actor_lr)
        self.critic_optimizer.update(critic_grads, stepsize=self.critic_lr)

        # from ipdb import set_trace
        # set_trace()
        return critic_loss, actor_loss, current_summary

    def initialize(self, sess):
        self.sess = sess
        self.sess.run(tf.global_variables_initializer())
        self.sync()
        self.sess.run(self.target_init_updates)

    def sync(self):
        self.actor_optimizer.sync()
        self.critic_optimizer.sync()
        if self.obs_rms:
            self.obs_rms.sync()
        if self.ret_rms:
            self.ret_rms.sync()

    def update_target_net(self):
        self.sess.run(self.target_soft_updates)

    def get_stats(self):
        if self.stats_sample is None:
            # Get a sample and keep that fixed for all further computations.
            # This allows us to estimate the change in value for the same set of inputs.
            self.stats_sample = self.memory.sample(batch_size=self.batch_size)

        fd = self.obs_feed_dict(self.obs0, self.stats_sample['obs0'])
        fd.update({self.actions: self.stats_sample['actions']})
        values = self.sess.run(self.stats_ops, feed_dict=fd)

        names = self.stats_names[:]
        assert len(names) == len(values)
        stats = dict(zip(names, values))

        if self.param_noise is not None:
            stats = {**stats, **self.param_noise.get_stats()}

        return stats

    def adapt_param_noise(self):
        if self.param_noise is None:
            return 0.

        # Perturb a separate copy of the policy to adjust the scale for the next "real" perturbation.
        batch = self.memory.sample(batch_size=self.batch_size)
        self.sess.run(self.perturb_adaptive_policy_ops,
                      feed_dict={
                          self.param_noise_stddev:
                          self.param_noise.current_stddev,
                      })

        fd = self.obs_feed_dict(self.obs0, batch['obs0'])
        fd.update({self.param_noise_stddev: self.param_noise.current_stddev})
        distance = self.sess.run(self.adaptive_policy_distance, feed_dict=fd)

        mean_distance = mpi_mean(distance)
        self.param_noise.adapt(mean_distance)
        return mean_distance

    def reset(self):
        # Reset internal state after an episode is complete.
        if self.action_noise is not None:
            self.action_noise.reset()

        if self.param_noise is not None:
            self.sess.run(self.perturb_policy_ops,
                          feed_dict={
                              self.param_noise_stddev:
                              self.param_noise.current_stddev,
                          })
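
The `single_train` methods above imply a collection/training split across MPI ranks: worker ranks act and ship transitions to rank 0, which stores them and runs all gradient updates. A hedged sketch of that loop follows; `env`, `agent`, `summary_var`, `nb_rollout_steps`, and `nb_train_steps` are assumed placeholders.

from mpi4py import MPI  # as used throughout the examples above

rank = MPI.COMM_WORLD.Get_rank()

if rank > 0:
    # Collection worker: act, route transitions to rank 0 via send_transition,
    # then signal completion so recv_transitions can return.
    obs = env.reset()
    for _ in range(nb_rollout_steps):
        action, _ = agent.pi(obs, apply_noise=True, compute_Q=False)
        new_obs, reward, done, _ = env.step(action)
        agent.store_transition(obs, action, reward, new_obs, done)
        obs = env.reset() if done else new_obs
    agent.finish_sending()
else:
    # Rank 0: drain incoming transitions, then train and update the targets.
    agent.recv_transitions()
    for _ in range(nb_train_steps):
        critic_loss, actor_loss, summary = agent.train(summary_var)
        agent.update_target_net()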