Example #1
 def setup_staging_areas(self):
     for idx, device in enumerate(self._devices):
         with tf.device(device):
             inputs = self._input.get_input_tensors()
             dtypes = [x.dtype for x in inputs]
             stage = StagingArea(dtypes, shapes=None)
             self._stage_ops.append(stage.put(inputs))
             self._areas.append(stage)
             outputs = stage.get()
             for vin, vout in zip(inputs, outputs):
                 vout.set_shape(vin.get_shape())
             self._unstage_ops.append(outputs)
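
A minimal usage sketch for the prefetcher above (hedged: prefetcher, train_op, and num_steps are assumptions, not names from the example). The usual pattern is to run every put op once to prime the staging areas, then run a put alongside each training step so the next batch is staged while the current one is read from _unstage_ops.

# Hypothetical driver loop; setup_staging_areas() is assumed to have run at graph-build time.
with tf.Session() as sess:
    sess.run(prefetcher._stage_ops)                   # prime: stage one batch per device
    for _ in range(num_steps):
        sess.run([train_op] + prefetcher._stage_ops)  # consume the staged batch, stage the next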
Example #2
 def stage_data(self, batch, memory_gb=1, n_threads=4):
     """Stage a batch of tensors on the GPU and register queue runners to keep the StagingArea filled."""
     with tf.device('/gpu:0'):
         dtypes = [t.dtype for t in batch]
         shapes = [t.get_shape() for t in batch]
         SA = StagingArea(dtypes,
                          shapes=shapes,
                          memory_limit=memory_gb * 1e9)
         get, put, clear = SA.get(), SA.put(batch), SA.clear()
     tf.train.add_queue_runner(
         tf.train.QueueRunner(queue=SA,
                              enqueue_ops=[put] * n_threads,
                              close_op=clear,
                              cancel_op=clear))
     return get
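
A sketch of how stage_data might be wired into a training script (hedged: loader, build_model, batch, and num_steps are assumptions). The returned get op pops one staged batch per sess.run, while the registered QueueRunner threads keep calling put once tf.train.start_queue_runners is active.

# Hypothetical usage; batch is assumed to come from an input pipeline built earlier in the graph.
staged = loader.stage_data(batch, memory_gb=1, n_threads=4)
train_op = build_model(staged)   # build_model is an assumption
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(num_steps):
        sess.run(train_op)
    coord.request_stop()
    coord.join(threads)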
Example #3
 def setup_staging_areas(self):
     for idx, device in enumerate(self._devices):
         with tf.device(device):
             inputs = self._input.get_input_tensors()
             dtypes = [x.dtype for x in inputs]
             stage = StagingArea(dtypes, shapes=None)
             self._stage_ops.append(stage.put(inputs))
             self._areas.append(stage)
             outputs = stage.get()
             if isinstance(
                     outputs,
                     tf.Tensor):  # when size=1, TF doesn't return a list
                 outputs = [outputs]
             for vin, vout in zip(inputs, outputs):
                 vout.set_shape(vin.get_shape())
             self._unstage_ops.append(outputs)
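
This variant is identical to Example #1 except for the isinstance guard: when the StagingArea holds a single tensor, get() returns that tensor directly rather than a one-element list, so it is wrapped in a list before the input shapes are copied over.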
Example #4
    def _prepare_staging(self):
        with tf.variable_scope('staging', reuse=tf.AUTO_REUSE):
            staging_area_tf = StagingArea(
                dtypes=[tf.float32 for _ in self._stage_shapes.keys()],
                shapes=[(None, *shape)
                        for shape in self._stage_shapes.values()])
            input_ph_tf = [
                tf.placeholder(tf.float32, shape=(None, *shape))
                for shape in self._stage_shapes.values()
            ]
            staging_op_tf = staging_area_tf.put(input_ph_tf)

            batch_tf = OrderedDict([
                (key, batch_item) for key, batch_item in zip(
                    self._stage_shapes.keys(), staging_area_tf.get())
            ])

        return staging_area_tf, input_ph_tf, staging_op_tf, batch_tf
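
A sketch of how the returned handles might be used (hedged: agent, sess, and numpy_batch are assumptions): one NumPy array is fed per placeholder, the staging op pushes the batch into the area, and a later sess.run over batch_tf pops it.

# Hypothetical usage; numpy_batch is assumed to be a dict keyed like agent._stage_shapes.
staging_area_tf, input_ph_tf, staging_op_tf, batch_tf = agent._prepare_staging()
feed = {ph: numpy_batch[key]
        for ph, key in zip(input_ph_tf, agent._stage_shapes.keys())}
sess.run(staging_op_tf, feed_dict=feed)      # push one batch into the staging area
values = sess.run(list(batch_tf.values()))   # pop and read the staged batch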
Example #5
class DDPG_PDDL(Policy):
    @store_args
    def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
                 Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
                 rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
                 sample_transitions, gamma, n_preds, reuse=False, **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        Policy.__init__(self, input_dims, T, rollout_batch_size, **kwargs)

        self.hidden = hidden
        self.layers = layers
        self.max_u = max_u
        self.network_class = network_class
        self.sample_transitions = sample_transitions
        self.scope = scope
        self.subtract_goals = subtract_goals
        self.relative_goals = relative_goals
        self.clip_obs = clip_obs
        self.Q_lr = Q_lr
        self.pi_lr = pi_lr
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.clip_pos_returns = clip_pos_returns
        self.gamma = gamma
        self.polyak = polyak
        self.clip_return = clip_return
        self.norm_eps = norm_eps
        self.norm_clip = norm_clip
        self.action_l2 = action_l2
        self.n_preds = n_preds
        self.rep_lr = Q_lr
        if self.clip_return is None:
            self.clip_return = np.inf
        self.create_actor_critic = import_function(self.network_class)
        self.rep_network = import_function(kwargs['rep_network_class'])


        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *self.input_shapes[key])
                         for key, val in self.input_shapes.items()}
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T+1, self.dimg)

        buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)

        # Create the representation network.
        with tf.variable_scope(self.scope):
            self._create_rep_network(reuse=reuse)
        self.obs2preds_buffer = Obs2PredsBuffer(buffer_len=2000)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
                    compute_Q=False, exploit=True):

        noise_eps = noise_eps if not exploit else 0.
        random_eps = random_eps if not exploit else 0.

        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf: o.reshape(-1, self.dimo),
            policy.g_tf: g.reshape(-1, self.dimg),
            policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(*u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
            transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _sync_rep_optimizers(self):
        self.rep_adam.sync()
        # self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)
        transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def train_representation(self):
        rep_batch_size = 64
        batch = self.obs2preds_buffer.sample_batch(rep_batch_size)
        indexes = batch['indexes']
        feed_dict = {self.obs2preds_model.inputs_o: batch['obs'],
                     self.obs2preds_model.inputs_g: batch['goals'],
                     self.obs2preds_model.preds: batch['preds']}

        rep_grad = self.sess.run([self.rep_grad_tf], feed_dict=feed_dict)[0]
        self.rep_adam.update(rep_grad, self.rep_lr)
        # opti_res, celoss, celosses = self.sess.run([self.obs2preds_model.optimizer,
        #                                   self.obs2preds_model.celoss,
        #                                   self.obs2preds_model.celosses],
        #               feed_dict=feed_dict)
        #
        # celosses = np.mean(celosses, axis=-1)
        _, celosses_after = self.predict_representation(batch)
        celoss = np.mean(celosses_after)
        return celoss, celosses_after, indexes

    def predict_representation(self, batch):
        feed_dict = {self.obs2preds_model.inputs_o: batch['obs'],
                     self.obs2preds_model.inputs_g: batch['goals']}
        pred_dist = self.sess.run([self.obs2preds_model.prob_out],
                                  feed_dict=feed_dict)
        losses = None
        if 'preds' in batch:
            preds = batch['preds']
            if len(preds.shape) != 3:
                preds_probdist = np.zeros(shape=[preds.shape[0], preds.shape[1], 2])
                for j, p in enumerate(preds):
                    for i, v in enumerate(p):
                        preds_probdist[j][i][int(v)] = 1
                preds = preds_probdist
            feed_dict.update({self.obs2preds_model.preds: preds})
            pred_dist, loss = self.sess.run([self.obs2preds_model.prob_out, self.obs2preds_model.celosses],
                                      feed_dict=feed_dict)
            loss = np.mean(loss, axis=-1)

            losses = np.reshape(loss, newshape=(preds.shape[0],))
        preds = prob_dist2discrete(pred_dist)
        return preds, losses



    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
        return res

    def _create_rep_network(self, reuse=False):
        self.obs2preds_model = self.rep_network(self.n_preds, self.dimo, self.dimg)
        self.rep_loss_tf = tf.reduce_mean(self.obs2preds_model.celoss)
        rep_grads_tf = tf.gradients(self.rep_loss_tf, self._vars('obs2preds'))
        self.rep_grad_tf = flatten_grads(grads=rep_grads_tf, var_list=self._vars('obs2preds'))
        self.rep_adam = MpiAdam(self._vars('obs2preds'), scale_grad_by_procs=False)
        self._sync_rep_optimizers()

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([(key, batch[i])
                                for i, key in enumerate(self.stage_shapes.keys())])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(
                target_batch_tf, net_type='target', **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        # [print(key, ": ", item) for key,item in self.__dict__.items()]
        excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
                             'main', 'target', 'lock', 'env', 'sample_transitions',
                             'stage_shapes', 'create_actor_critic',
                             'obs2preds_buffer', 'obs2preds_model']

        state = {k: v for k, v in self.__dict__.items()
                 if all(subname not in k for subname in excluded_subnames)}
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name and 'obs2preds_buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name and 'obs2preds_buffer' not in x.name]
        assert(len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
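
A minimal driver sketch for the class above (hedged: policy, rollout_worker, n_epochs, and n_batches are assumptions borrowed from the usual HER training setup, not code shown here): episodes are stored in the replay buffer, each train() call stages a sampled batch and applies one actor/critic update, and the target network is moved afterwards.

# Hypothetical training loop around DDPG_PDDL.
for epoch in range(n_epochs):
    episode = rollout_worker.generate_rollouts()   # generate_rollouts is an assumption
    policy.store_episode(episode)
    for _ in range(n_batches):
        critic_loss, actor_loss = policy.train()
    policy.update_target_net()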
Example #6
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class_actor_critic,
                 network_class_discriminator,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 mi_lr,
                 sk_lr,
                 r_scale,
                 mi_r_scale,
                 sk_r_scale,
                 et_r_scale,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 env_name,
                 max_timesteps,
                 pretrain_weights,
                 finetune_pi,
                 mi_prioritization,
                 sac,
                 reuse=False,
                 history_len=10000,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class_actor_critic (str): the network class used for the actor-critic (e.g. 'baselines.her.ActorCritic')
            network_class_discriminator (str): the network class used for the discriminator
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(
            self.network_class_actor_critic)
        self.create_discriminator = import_function(
            self.network_class_discriminator)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimz = self.input_dims['z']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        self.env_name = env_name

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        stage_shapes['w'] = (None, )
        stage_shapes['m'] = (None, )
        stage_shapes['s'] = (None, )
        stage_shapes['m_w'] = ()
        stage_shapes['s_w'] = ()
        stage_shapes['r_w'] = ()
        stage_shapes['e_w'] = ()
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(pretrain_weights,
                                 mi_prioritization,
                                 reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T if key != 'o' else self.T + 1, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)
        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size

        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions, mi_prioritization)

        self.mi_r_history = deque(maxlen=history_len)
        self.gl_r_history = deque(maxlen=history_len)
        self.sk_r_history = deque(maxlen=history_len)
        self.et_r_history = deque(maxlen=history_len)
        self.mi_current = 0
        self.finetune_pi = finetune_pi

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self,
                    o,
                    z,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        if self.sac:
            vals = [policy.mu_tf]
        else:
            vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]

        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.z_tf:
            z.reshape(-1, self.dimz),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)

        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        # add placeholders for the mutual-information ('m') and skill ('s') rewards to the episode batch
        episode_batch['m'] = np.empty([episode_batch['o'].shape[0], 1])
        episode_batch['s'] = np.empty([episode_batch['o'].shape[0], 1])

        self.buffer.store_episode(episode_batch, self)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)

            transitions = self.sample_transitions(self, False, episode_batch,
                                                  num_normalizing_transitions,
                                                  0, 0, 0)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()
        self.mi_adam.sync()
        self.sk_adam.sync()

    def _grads_mi(self, data):
        mi, mi_grad = self.sess.run([
            self.main_ir.mi_tf,
            self.mi_grad_tf,
        ],
                                    feed_dict={self.o_tau_tf: data})
        return mi, mi_grad

    def _grads_sk(self, o_s_batch, z_s_batch):
        sk, sk_grad = self.sess.run([
            self.main_ir.sk_tf,
            self.sk_grad_tf,
        ],
                                    feed_dict={
                                        self.main_ir.o_tf: o_s_batch,
                                        self.main_ir.z_tf: z_s_batch
                                    })
        return sk, sk_grad

    def _grads(self):
        critic_loss, actor_loss, Q_grad, pi_grad, neg_logp_pi, e_w = self.sess.run(
            [
                self.Q_loss_tf,
                self.main.Q_pi_tf,
                self.Q_grad_tf,
                self.pi_grad_tf,
                self.main.neg_logp_pi_tf,
                self.e_w_tf,
            ])
        return critic_loss, actor_loss, Q_grad, pi_grad, neg_logp_pi, e_w

    def _update_mi(self, mi_grad):
        self.mi_adam.update(mi_grad, self.mi_lr)

    def _update_sk(self, sk_grad):
        self.sk_adam.update(sk_grad, self.sk_lr)

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self, ir, t):

        transitions = self.buffer.sample(self, ir, self.batch_size,
                                         self.mi_r_scale, self.sk_r_scale, t)
        weights = np.ones_like(transitions['r']).copy()
        if ir:
            self.mi_r_history.extend(
                ((np.clip((self.mi_r_scale * transitions['m']), *(0, 1)) -
                  (1 if not self.mi_r_scale == 0 else 0)) *
                 transitions['m_w']).tolist())
            self.sk_r_history.extend(
                ((np.clip(self.sk_r_scale * transitions['s'], *(-1, 0))) *
                 1.00).tolist())
            self.gl_r_history.extend(self.r_scale * transitions['r'])

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions['w'] = weights.flatten().copy()  # note: ordered dict
        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]

        return transitions_batch

    def stage_batch(self, ir, t, batch=None):
        if batch is None:
            batch = self.sample_batch(ir, t)
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def run_mi(self, o_s):
        feed_dict = {self.o_tau_tf: o_s.copy()}
        neg_l = self.sess.run(self.main_ir.mi_tf, feed_dict=feed_dict)
        return neg_l

    def run_sk(self, o, z):
        feed_dict = {self.main_ir.o_tf: o, self.main_ir.z_tf: z}
        sk_r = self.sess.run(self.main_ir.sk_r_tf, feed_dict=feed_dict)
        return sk_r

    def train_mi(self, data, stage=True):
        mi, mi_grad = self._grads_mi(data)
        self._update_mi(mi_grad)
        self.mi_current = -mi.mean()
        return -mi.mean()

    def train_sk(self, o_s_batch, z_s_batch, stage=True):
        sk, sk_grad = self._grads_sk(o_s_batch, z_s_batch)
        self._update_sk(sk_grad)
        return -sk.mean()

    def train(self, t, stage=True):
        if self.buffer.current_size != 0:
            if stage:
                self.stage_batch(ir=True, t=t)
            critic_loss, actor_loss, Q_grad, pi_grad, neg_logp_pi, e_w = \
                self._grads()
            self._update(Q_grad, pi_grad)
            self.et_r_history.extend(((np.clip(
                (self.et_r_scale * neg_logp_pi), *(-1, 0))) * e_w).tolist())
            return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self,
                        pretrain_weights,
                        mi_prioritization,
                        reuse=False):
        if self.sac:
            logger.info("Creating a SAC agent with action space %d x %s..." %
                        (self.dimu, self.max_u))
        else:
            logger.info("Creating a DDPG agent with action space %d x %s..." %
                        (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
        batch_tf['w'] = tf.reshape(batch_tf['w'], [-1, 1])
        batch_tf['m'] = tf.reshape(batch_tf['m'], [-1, 1])
        batch_tf['s'] = tf.reshape(batch_tf['s'], [-1, 1])

        self.o_tau_tf = tf.placeholder(tf.float32,
                                       shape=(None, None, self.dimo))

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # intrinsic reward (ir) network for mutual information
        with tf.variable_scope('ir') as vs:
            if reuse:
                vs.reuse_variables()
            self.main_ir = self.create_discriminator(batch_tf,
                                                     net_type='ir',
                                                     **self.__dict__)
            vs.reuse_variables()

        # loss functions

        mi_grads_tf = tf.gradients(tf.reduce_mean(self.main_ir.mi_tf),
                                   self._vars('ir/state_mi'))
        assert len(self._vars('ir/state_mi')) == len(mi_grads_tf)
        self.mi_grads_vars_tf = zip(mi_grads_tf, self._vars('ir/state_mi'))
        self.mi_grad_tf = flatten_grads(grads=mi_grads_tf,
                                        var_list=self._vars('ir/state_mi'))
        self.mi_adam = MpiAdam(self._vars('ir/state_mi'),
                               scale_grad_by_procs=False)

        sk_grads_tf = tf.gradients(tf.reduce_mean(self.main_ir.sk_tf),
                                   self._vars('ir/skill_ds'))
        assert len(self._vars('ir/skill_ds')) == len(sk_grads_tf)
        self.sk_grads_vars_tf = zip(sk_grads_tf, self._vars('ir/skill_ds'))
        self.sk_grad_tf = flatten_grads(grads=sk_grads_tf,
                                        var_list=self._vars('ir/skill_ds'))
        self.sk_adam = MpiAdam(self._vars('ir/skill_ds'),
                               scale_grad_by_procs=False)

        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      self.clip_return if self.clip_pos_returns else np.inf)

        self.e_w_tf = batch_tf['e_w']

        if not self.sac:
            self.main.neg_logp_pi_tf = tf.zeros(1)

        target_tf = tf.clip_by_value(
            self.r_scale * batch_tf['r'] * batch_tf['r_w'] +
            (tf.clip_by_value(self.mi_r_scale * batch_tf['m'], *(0, 1)) -
             (1 if not self.mi_r_scale == 0 else 0)) * batch_tf['m_w'] +
            (tf.clip_by_value(self.sk_r_scale * batch_tf['s'], *(-1, 0))) *
            batch_tf['s_w'] +
            (tf.clip_by_value(self.et_r_scale * self.main.neg_logp_pi_tf,
                              *(-1, 0))) * self.e_w_tf +
            self.gamma * target_Q_pi_tf, *clip_range)

        self.td_error_tf = tf.stop_gradient(target_tf) - self.main.Q_tf
        self.errors_tf = tf.square(self.td_error_tf)
        self.errors_tf = tf.reduce_mean(batch_tf['w'] * self.errors_tf)
        self.Q_loss_tf = tf.reduce_mean(self.errors_tf)

        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')

        # polyak averaging
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        if pretrain_weights:
            load_weight(self.sess, pretrain_weights, ['state_mi'])
            if self.finetune_pi:
                load_weight(self.sess, pretrain_weights, ['main'])

        self._sync_optimizers()
        if pretrain_weights and self.finetune_pi:
            load_weight(self.sess, pretrain_weights, ['target'])
        else:
            self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
        logs += [('mi_reward/mean', np.mean(self.mi_r_history))]
        logs += [('mi_reward/std', np.std(self.mi_r_history))]
        logs += [('mi_reward/max', np.max(self.mi_r_history))]
        logs += [('mi_reward/min', np.min(self.mi_r_history))]
        logs += [('mi_train/-neg_l', self.mi_current)]
        logs += [('sk_reward/mean', np.mean(self.sk_r_history))]
        logs += [('sk_reward/std', np.std(self.sk_r_history))]
        logs += [('sk_reward/max', np.max(self.sk_r_history))]
        logs += [('sk_reward/min', np.min(self.sk_r_history))]
        logs += [('et_reward/mean', np.mean(self.et_r_history))]
        logs += [('et_reward/std', np.std(self.et_r_history))]
        logs += [('et_reward/max', np.max(self.et_r_history))]
        logs += [('et_reward/min', np.min(self.et_r_history))]
        logs += [('gl_reward/mean', np.mean(self.gl_r_history))]
        logs += [('gl_reward/std', np.std(self.gl_r_history))]
        logs += [('gl_reward/max', np.max(self.gl_r_history))]
        logs += [('gl_reward/min', np.min(self.gl_r_history))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'sample_transitions', 'stage_shapes',
            'create_actor_critic', 'create_discriminator', '_history'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all(subname not in k for subname in excluded_subnames)
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None
        if 'env_name' not in state:
            state['env_name'] = 'FetchPickAndPlace-v1'
        if 'network_class_discriminator' not in state:
            state['network_class_discriminator'] = (
                'baselines.her.discriminator:Discriminator')
        if 'mi_r_scale' not in state:
            state['mi_r_scale'] = 1
        if 'mi_lr' not in state:
            state['mi_lr'] = 0.001
        if 'sk_r_scale' not in state:
            state['sk_r_scale'] = 1
        if 'sk_lr' not in state:
            state['sk_lr'] = 0.001
        if 'et_r_scale' not in state:
            state['et_r_scale'] = 1
        if 'finetune_pi' not in state:
            state['finetune_pi'] = None
        if 'no_train_mi' not in state:
            state['no_train_mi'] = None
        if 'load_weight' not in state:
            state['load_weight'] = None
        if 'pretrain_weights' not in state:
            state['pretrain_weights'] = None
        if 'mi_prioritization' not in state:
            state['mi_prioritization'] = None
        if 'sac' not in state:
            state['sac'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
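
A sketch of how the intrinsic-reward heads in this class might be updated alongside the policy (hedged: agent, o_tau, o_s, z_s, and t are assumptions about the surrounding training loop): train_mi and train_sk each take one MpiAdam step on the mutual-information and skill discriminators, while train(t) stages a batch with ir=True and applies the DDPG/SAC update before the target network is refreshed.

# Hypothetical per-cycle update; the sampled arrays are assumed to come from the replay buffer.
mi_loss = agent.train_mi(o_tau)            # mutual-information discriminator step
sk_loss = agent.train_sk(o_s, z_s)         # skill-discriminator step
critic_loss, actor_loss = agent.train(t)   # stage a batch (ir=True) and update actor/critic
agent.update_target_net()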
Example #7
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 temperature,
                 prioritization,
                 env_name,
                 alpha,
                 beta0,
                 beta_iters,
                 eps,
                 max_timesteps,
                 rank_method,
                 reuse=False,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        self.prioritization = prioritization
        self.env_name = env_name
        self.temperature = temperature
        self.rank_method = rank_method

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        stage_shapes['w'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T if key != 'o' else self.T + 1, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)
        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size

        if self.prioritization == 'entropy':
            self.buffer = ReplayBufferEntropy(buffer_shapes, buffer_size,
                                              self.T, self.sample_transitions,
                                              self.prioritization,
                                              self.env_name)
        elif self.prioritization == 'tderror':
            self.buffer = PrioritizedReplayBuffer(buffer_shapes, buffer_size,
                                                  self.T,
                                                  self.sample_transitions,
                                                  alpha, self.env_name)
            if beta_iters is None:
                beta_iters = max_timesteps
            self.beta_schedule = LinearSchedule(beta_iters,
                                                initial_p=beta0,
                                                final_p=1.0)
        else:
            self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                       self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)

        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def get_td_errors(self, o, g, u):
        o, g = self._preprocess_og(o, g, g)
        vals = [self.td_error_tf]
        r = np.ones((o.reshape(-1, self.dimo).shape[0], 1))

        feed = {
            self.target.o_tf: o.reshape(-1, self.dimo),
            self.target.g_tf: g.reshape(-1, self.dimg),
            self.bath_tf_r: r,
            self.main.o_tf: o.reshape(-1, self.dimo),
            self.main.g_tf: g.reshape(-1, self.dimg),
            self.main.u_tf: u.reshape(-1, self.dimu)
        }
        td_errors = self.sess.run(vals, feed_dict=feed)
        td_errors = td_errors.copy()

        return td_errors

    def fit_density_model(self):
        self.buffer.fit_density_model()

    def store_episode(self,
                      episode_batch,
                      dump_buffer,
                      rank_method,
                      epoch,
                      update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """
        if self.prioritization == 'tderror':
            self.buffer.store_episode(episode_batch, dump_buffer)
        elif self.prioritization == 'entropy':
            self.buffer.store_episode(episode_batch, rank_method, epoch)
        else:
            self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)

            if self.prioritization == 'entropy':
                if not self.buffer.current_size == 0 and not len(
                        episode_batch['ag']) == 0:
                    transitions = self.sample_transitions(
                        episode_batch, num_normalizing_transitions, 'none',
                        1.0, True)
            elif self.prioritization == 'tderror':
                transitions, weights, episode_idxs = self.sample_transitions(
                    self.buffer, episode_batch, num_normalizing_transitions,
                    beta=0)
            else:
                transitions = self.sample_transitions(
                    episode_batch, num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def dump_buffer(self, epoch):
        self.buffer.dump_buffer(epoch)

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad, td_error = self.sess.run([
            self.Q_loss_tf, self.main.Q_pi_tf, self.Q_grad_tf, self.pi_grad_tf,
            self.td_error_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad, td_error

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self, t):

        if self.prioritization == 'entropy':
            transitions = self.buffer.sample(self.batch_size,
                                             self.rank_method,
                                             temperature=self.temperature)
            weights = np.ones_like(transitions['r']).copy()
        elif self.prioritization == 'tderror':
            transitions, weights, idxs = self.buffer.sample(
                self.batch_size, beta=self.beta_schedule.value(t))
        else:
            transitions = self.buffer.sample(self.batch_size)
            weights = np.ones_like(transitions['r']).copy()

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions['w'] = weights.flatten().copy()  # note: ordered dict
        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
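        # Note: the list above must follow the key order of self.stage_shapes,
        # since stage_batch() zips it with self.buffer_ph_tf below.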

        if self.prioritization == 'tderror':
            return (transitions_batch, idxs)
        else:
            return transitions_batch

    def stage_batch(self, t, batch=None):
        if batch is None:
            if self.prioritization == 'tderror':
                batch, idxs = self.sample_batch(t)
            else:
                batch = self.sample_batch(t)
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

        if self.prioritization == 'tderror':
            return idxs

    def train(self, t, dump_buffer, stage=True):
        if not self.buffer.current_size == 0:
            if stage:
                if self.prioritization == 'tderror':
                    idxs = self.stage_batch(t)
                else:
                    self.stage_batch(t)
            critic_loss, actor_loss, Q_grad, pi_grad, td_error = self._grads()
            if self.prioritization == 'tderror':
                new_priorities = np.abs(td_error) + self.eps  # td_error

                if dump_buffer:
                    T = self.buffer.buffers['u'].shape[1]
                    episode_idxs = idxs // T
                    t_samples = idxs % T
                    batch_size = td_error.shape[0]
                    with self.buffer.lock:
                        for i in range(batch_size):
                            self.buffer.buffers['td'][episode_idxs[i]][
                                t_samples[i]] = td_error[i]

                self.buffer.update_priorities(idxs, new_priorities)
            self._update(Q_grad, pi_grad)
            return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
        batch_tf['w'] = tf.reshape(batch_tf['w'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)

        self.td_error_tf = tf.stop_gradient(target_tf) - self.main.Q_tf
        self.errors_tf = tf.square(self.td_error_tf)
        self.errors_tf = tf.reduce_mean(batch_tf['w'] * self.errors_tf)
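        # batch_tf['w'] holds per-transition importance weights; they are all
        # ones unless td-error prioritization is used (see sample_batch above).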
        self.Q_loss_tf = tf.reduce_mean(self.errors_tf)
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))
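        # i.e. target <- polyak * target + (1 - polyak) * main, applied
        # element-wise to every target variable after each training step.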

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None
        state['env_name'] = None  # No need for playing the policy

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
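All of the DDPG variants in these examples wire their mini-batch pipeline the same way: placeholders built from stage_shapes are put into a StagingArea, and the training graph reads from StagingArea.get() instead of a feed_dict. Below is a minimal, self-contained sketch of that pattern, assuming TensorFlow 1.x; the import path, shapes, dummy loss, and zero-filled batch are illustrative assumptions, not taken from any one example above.

from collections import OrderedDict

import numpy as np
import tensorflow as tf
from tensorflow.python.ops.data_flow_ops import StagingArea  # one TF 1.x location of StagingArea

# Illustrative shapes only: observation, goal, scalar reward.
stage_shapes = OrderedDict([('o', (None, 10)), ('g', (None, 3)), ('r', (None,))])

staging_tf = StagingArea(
    dtypes=[tf.float32 for _ in stage_shapes],
    shapes=list(stage_shapes.values()))
buffer_ph_tf = [tf.placeholder(tf.float32, shape=s) for s in stage_shapes.values()]
stage_op = staging_tf.put(buffer_ph_tf)

# The rest of the graph consumes the staged batch instead of a feed_dict.
batch = staging_tf.get()
batch_tf = OrderedDict([(key, batch[i]) for i, key in enumerate(stage_shapes.keys())])
dummy_loss_tf = tf.reduce_mean(batch_tf['r'])  # stand-in for a real loss

with tf.Session() as sess:
    feed = {ph: np.zeros((4,) + shape[1:], dtype=np.float32)
            for ph, shape in zip(buffer_ph_tf, stage_shapes.values())}
    sess.run(stage_op, feed_dict=feed)  # stage one batch of 4 dummy transitions
    print(sess.run(dummy_loss_tf))      # this run dequeues the staged batch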
示例#8
0
class DDPG(object):
    @store_args
    def __init__(self, FLAGS, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
                 Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
                 rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
                 bc_loss, q_filter, num_demo, demo_batch_size, prm_loss_weight, aux_loss_weight,
                 
                #  sample_transitions, gamma, reuse=False, **kwargs):
                 sample_transitions, gamma, td3_policy_freq, td3_policy_noise, td3_noise_clip, reuse=False, *agent_params, **kwargs): ##
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
            Added functionality to use demonstrations for training to overcome the exploration problem.

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
            bc_loss: whether or not the behavior cloning loss should be used as an auxiliary loss
            q_filter: whether or not a filter on the Q value update should be used when training with demonstrations
            num_demo: number of episodes to be used in the demonstration buffer
            demo_batch_size: number of samples to be used from the demonstration buffer, per MPI thread
            prm_loss_weight: weight corresponding to the primary loss
            aux_loss_weight: weight corresponding to the auxiliary loss, also called the cloning loss

            agent_params: for HAC agent params
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        # self.dimo1= self.input_dims['o1'] ##A.R add for TD3 (has obs0, obs1)
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']
        # Added content:
        #parameters for using TD3 variant of DDPG
        #https://arxiv.org/abs/1802.09477
        self.td3_policy_freq = td3_policy_freq
        self.td3_policy_noise = td3_policy_noise
        self.td3_noise_clip = td3_noise_clip

        ## for HAC
        self.FLAGS = FLAGS

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
        # for key in ['o', 'o1', 'g']: #o1 added by A.R
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None,)
        self.stage_shapes = stage_shapes
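        # e.g. with FetchReach-style dims (o=10, u=4, g=3) this yields
        # OrderedDict([('g', (None, 3)), ('o', (None, 10)), ('u', (None, 4)),
        #              ('o_2', (None, 10)), ('g_2', (None, 3)), ('r', (None,))])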

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {key: (self.T-1 if key != 'o' else self.T, *input_shapes[key]) 
        # origin : buffer_shapes = {key: (self.T-1 if key != 'o' else self.T, *input_shapes[key]) 
        # buffer_shapes = {key: (self.T-1 if key != 'o' and key != 'o1' else self.T, *input_shapes[key]) #A.R
                         for key, val in input_shapes.items()}
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T, self.dimg)

        buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)

        global DEMO_BUFFER
        DEMO_BUFFER = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer; in the same way as the primary data buffer
        print("@ ddgp.py , buffer={}".format(self.buffer))
        
        # self.meta_controller = DDPG(self.dimo + self.dimg, self.dimo, self.clip_obs)
        # ##
        # self.low_replay_buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)
        # self.high_replay_buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)
        # ##

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
    # def _preprocess_og(self, o, o1, ag, g): #A.R
        if self.relative_goals:  ## where the goal gets reshaped (ag vs. g)
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)  # turns g into a goal relative to ag
            '''
            def simple_goal_subtract(a, b):
            assert a.shape == b.shape
            return a - b
            '''
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        # o1 = np.clip(o1, -self.clip_obs, self.clip_obs) #A.R
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        # return o, o1, g
        return o, g

    def step(self, obs):
        # FLAGS = FLAGS
        actions = self.get_actions(obs['observation'], obs['achieved_goal'], obs['desired_goal'])
        # actions = self.get_actions(obs['observation'], obs['achieved_goal'], obs['desired_goal'], FLAGS)
        # print("for debug, obs : {}".format(obs['observation']))
        return actions, None, None, None


    # def get_actions(self, o, o1, ag, g, noise_eps=0., random_eps=0., use_target_net=False, ## o1 is for the target network
    def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
    # def get_actions(self, o, ag, g, FLAGS, noise_eps=0., random_eps=0., use_target_net=False,
                    compute_Q=False):
        # o, o1, g = self._preprocess_og(o, o1, ag, g) ##
        
        o, g = self._preprocess_og(o, ag, g) 
        policy = self.target if use_target_net else self.main  # called from rollout.py
        # values to compute
        vals = [policy.pi_tf]
        
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf: o.reshape(-1, self.dimo),
            policy.g_tf: g.reshape(-1, self.dimg),
            policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(*u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def init_demo_buffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer

        demoData = np.load(demoDataFile) #load the demonstration data from data file
        info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
        info_values = [np.empty((self.T - 1, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]

        demo_data_obs = demoData['obs']
        demo_data_acs = demoData['acs']
        demo_data_info = demoData['info']

        for epsd in range(self.num_demo): # we initialize the whole demo buffer at the start of the training
            obs, acts, goals, achieved_goals = [], [], [], []
            i = 0
            for transition in range(self.T - 1):
                obs.append([demo_data_obs[epsd][transition].get('observation')])
                acts.append([demo_data_acs[epsd][transition]])
                goals.append([demo_data_obs[epsd][transition].get('desired_goal')])
                achieved_goals.append([demo_data_obs[epsd][transition].get('achieved_goal')])
                for idx, key in enumerate(info_keys):
                    info_values[idx][transition, i] = demo_data_info[epsd][transition][key]


            obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
            achieved_goals.append([demo_data_obs[epsd][self.T - 1].get('achieved_goal')])

            episode = dict(o=obs,
                           u=acts,
                           g=goals,
                           ag=achieved_goals)
            for key, value in zip(info_keys, info_values):
                episode['info_{}'.format(key)] = value

            episode = convert_episode_to_batch_major(episode)
            global DEMO_BUFFER
            DEMO_BUFFER.store_episode(episode) # create the observation dict and append them into the demonstration buffer
            logger.debug("Demo buffer size currently ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size

            if update_stats:
                # add transitions to normalizer to normalize the demo data as well
                episode['o_2'] = episode['o'][:, 1:, :]
                episode['ag_2'] = episode['ag'][:, 1:, :]
                num_normalizing_transitions = transitions_in_episode_batch(episode)
                transitions = self.sample_transitions(episode, num_normalizing_transitions)

                o, g, ag = transitions['o'], transitions['g'], transitions['ag']
                transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
                # No need to preprocess the o_2 and g_2 since this is only used for stats

                self.o_stats.update(transitions['o'])
                self.g_stats.update(transitions['g'])

                self.o_stats.recompute_stats()
                self.g_stats.recompute_stats()
            episode.clear()

        logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
            transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        if self.bc_loss: #use demonstration buffer to sample as well if bc_loss flag is set TRUE
            transitions = self.buffer.sample(self.batch_size - self.demo_batch_size)
            global DEMO_BUFFER
            transitions_demo = DEMO_BUFFER.sample(self.demo_batch_size) #sample from the demo buffer
            for k, values in transitions_demo.items():
                rolloutV = transitions[k].tolist()
                for v in values:
                    rolloutV.append(v.tolist())
                transitions[k] = np.array(rolloutV)
        else:
            transitions = self.buffer.sample(self.batch_size) #otherwise only sample from primary buffer

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        # o1, o1_2, g = transitions['o1'], transitions['o1_2'] ## A.R
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)

        transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
        print("@ ddpg, sample_batch, transitions_batch={}".format(transitions_batch))
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()  ## fetch the current losses and gradients
        self._update(Q_grad, pi_grad)  ## apply the Adam updates
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target1_net_op)
        self.sess.run(self.init_target2_net_op)

    def update_target_net(self):
        # self.sess.run(self.update_target_net_op)
        self.sess.run(self.update_target1_net_op)
        self.sess.run(self.update_target2_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
        assert len(res) > 0  # why does this assertion trip here, and then not trip again?
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
        # print("DEBUG, {}".format(res))
        
        return res

    def _create_network(self, reuse=False):  ## num_demo added (-2)
        logger.info("Debug : Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))
        self.sess = tf_util.get_session()
        # self.num_demo = num_demo

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()  ## simply dequeues the staged batch
        batch_tf = OrderedDict([(key, batch[i])
                                for i, key in enumerate(self.stage_shapes.keys())])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        #choose only the demo buffer samples
        mask = np.concatenate((np.zeros(self.batch_size - self.demo_batch_size), np.ones(self.demo_batch_size)), axis = 0)

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
            vs.reuse_variables()
            print("tf.variable_scope(main) = {}".format(tf.variable_scope('target1'))) #-1

        with tf.variable_scope('target1') as vs:
            if reuse:
                vs.reuse_variables()
            target1_batch_tf = batch_tf.copy()
            target1_batch_tf['o'] = batch_tf['o_2']
            target1_batch_tf['g'] = batch_tf['g_2']
            self.target1 = self.create_actor_critic(
                target1_batch_tf, net_type='target1', **self.__dict__)
            vs.reuse_variables()
            print("tf.variable_scope(target1) = {}".format(tf.variable_scope('target1')))
            # print("batch= {}".format(target1_batch_tf))
            # print(type('target')) #<class 'baselines.her.actor_critic.ActorCritic'>
        assert len(self._vars("main")) == len(self._vars("target1"))

        with tf.variable_scope('target2') as vs:
            if reuse:
                vs.reuse_variables()
            target2_batch_tf = batch_tf.copy()
            target2_batch_tf['o'] = batch_tf['o_2']
            target2_batch_tf['g'] = batch_tf['g_2']
            self.target2 = self.create_actor_critic(
                target2_batch_tf, net_type='target2', **self.__dict__)
            vs.reuse_variables()
            print("tf.variable_scope(target2) = {}".format(tf.variable_scope('target2')))
            print("batch= {}".format(target2_batch_tf))
        assert len(self._vars("main")) == len(self._vars("target2"))

        for nd in range(self.num_demo):       

            ##A.R
            ## Compute the target Q value; use the minimum of Q1 and Q2.

            target1_Q_pi_tf = self.target1.Q_pi_tf ##A.R policy training
            target2_Q_pi_tf = self.target2.Q_pi_tf ##A.R
            # target_Q_pi_tf = tf.minimum(target1_Q_pi_tf, target2_Q_pi_tf)
            # target1_Q_tf = self.target1.Q_tf ##A.R policy training
            # target2_Q_tf = self.target2.Q_tf ##A.R
            # print('target1={}/////target2={}'.format(target1_Q_tf,target2_Q_tf))
            target_Q_pi_tf = tf.minimum(target1_Q_pi_tf, target2_Q_pi_tf)
            # target_Q_tf = tf.minimum(target1_Q_tf, target2_Q_tf)  ## alternative code

            # print("{}///{}///{}".format(target1_Q_pi_tf,target2_Q_pi_tf,tf.minimum(target1_Q_pi_tf, target2_Q_pi_tf)))
            ####
            # Code from TD3 missing here: target_Q = reward + (done * discount * target_Q).detach() (L109) -> done below, followed by clipping

            # loss functions
            # for policy training, Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1])
            # target_Q_pi_tf = self.target.Q_pi_tf #original code
            clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
            target_Q_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
            # target_Q_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_tf, *clip_range)  ## alternative code
            # self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
            ##
            # current_Q1, current_Q2 = self.critic(state, action)

            # for critic training, Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True)
            # target_Q_pi_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_tf, *clip_range) #original code
            
            # self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf)) #critic taining 
            
            ## Get current Q estimates, for critic Q
            current_Q1 = self.main.Q_tf ##A.R
            current_Q2 = self.main.Q_tf
            # print("Q1={}".format(current_Q1))

            ## Compute critic loss
            ## Torch => critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 
            self.Q_loss_tf = tf.losses.mean_squared_error(current_Q1, target_Q_tf)+ tf.losses.mean_squared_error(current_Q2,target_Q_tf)
            # self.Q_loss_tf = tf.losses.mean_squared_error(current_Q1, target_Q_tf)+ tf.losses.mean_squared_error(current_Q2,target_Q_tf)
            # print("critic_loss ={}".format(self.Q_loss_tf))

            Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
            assert len(self._vars('main/Q')) == len(Q_grads_tf)

            ## Optimize the critic with the Adam optimizer
            self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
            assert len(self._vars('main/Q')) == len(Q_grads_tf)
            self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
            self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))

            # ## Delayed policy updates
            if nd % self.td3_policy_freq == 0:
                # print("num_demo = {}".format(nd))
                target1_Q_pi_tf = self.target1.Q_pi_tf ##A.R policy training
                target2_Q_pi_tf = self.target2.Q_pi_tf ##A.R
                tf.print(target1_Q_pi_tf, [target1_Q_pi_tf])
                tf.print(target2_Q_pi_tf, [target2_Q_pi_tf])
                # print(target2_Q_pi_tf)
                target_Q_pi_tf = tf.minimum(target1_Q_pi_tf, target2_Q_pi_tf)

                # target_Q_pi_tf = self.target.Q_pi_tf
                clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
                target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
                self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
                # Compute actor loss
                if self.bc_loss ==1 and self.q_filter == 1 : # train with demonstrations and use bc_loss and q_filter both
                    maskMain = tf.reshape(tf.boolean_mask(self.main.Q_tf > self.main.Q_pi_tf, mask), [-1]) #where is the demonstrator action better than actor action according to the critic? choose those samples only
                    #define the cloning loss on the actor's actions only on the samples which adhere to the above masks
                    self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask(tf.boolean_mask((self.main.pi_tf), mask), maskMain, axis=0) - tf.boolean_mask(tf.boolean_mask((batch_tf['u']), mask), maskMain, axis=0)))
                    self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf) #primary loss scaled by it's respective weight prm_loss_weight
                    self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u)) #L2 loss on action values scaled by the same weight prm_loss_weight
                    self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf #adding the cloning loss to the actor loss as an auxilliary loss scaled by its weight aux_loss_weight

                elif self.bc_loss == 1 and self.q_filter == 0: # train with demonstrations without q_filter
                    self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask((self.main.pi_tf), mask) - tf.boolean_mask((batch_tf['u']), mask)))
                    self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf)
                    self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
                    self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf

                else: #If  not training with demonstrations
                    self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
                    self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
                # self.pi_loss_tf = -tf.reduce_mean(self.main.pi_tf) ## what about target1?
                # self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
                # actor_loss = -tf.reduce_mean(self.main.Q_tf)
                # actor_loss += self.action_l2 * tf.reduce_mean(tf.square(self.main.Q_tf / self.max_u))

                pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
                assert len(self._vars('main/pi')) == len(pi_grads_tf)

                # Optimize the actor 
                # Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
                self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)
                assert len(self._vars('main/pi')) == len(pi_grads_tf)
                self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
                self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

                # Update the frozen target models
            ## torch code
                # for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                #     target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                

                self.main_vars = self._vars('main/Q') + self._vars('main/pi')
                self.target1_vars = self._vars('target1/Q') + self._vars('target1/pi') ##A.R
                self.target2_vars = self._vars('target2/Q') + self._vars('target2/pi') ##A.R
                if target_Q_pi_tf == target1_Q_pi_tf:
                    target_vars = self.target1_vars
                else:
                    target_vars = self.target2_vars
                # self.target_vars = self._vars('target/Q') + self._vars('target/pi') #original
                self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
                self.init_target1_net_op = list(
                    map(lambda v: v[0].assign(v[1]), zip(self.target1_vars, self.main_vars)))
                self.init_target2_net_op = list(
                    map(lambda v: v[0].assign(v[1]), zip(self.target2_vars, self.main_vars)))

                self.update_target_net_op = list(
                    map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(target_vars, self.main_vars)))
                self.update_target1_net_op = list(
                    map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(target_vars, self.main_vars)))
                self.update_target2_net_op = list(
                    map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(target_vars, self.main_vars)))


                tf.variables_initializer(self._global_vars('')).run()
                self._sync_optimizers()
                self._init_target_net()



        # Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        # pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        # assert len(self._vars('main/Q')) == len(Q_grads_tf)
        # assert len(self._vars('main/pi')) == len(pi_grads_tf)
        # self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        # self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        # self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
        # self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

        # optimizers
        # self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        # self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)

        # polyak averaging
        # self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        # self.target1_vars = self._vars('target1/Q') + self._vars('target1/pi') ##A.R
        # self.target2_vars = self._vars('target2/Q') + self._vars('target2/pi') ##A.R
        # # self.target_vars = self._vars('target/Q') + self._vars('target/pi') #original
        # self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
        # self.init_target1_net_op = list(
        #     map(lambda v: v[0].assign(v[1]), zip(self.target1_vars, self.main_vars)))
        # self.init_target2_net_op = list(
        #     map(lambda v: v[0].assign(v[1]), zip(self.target2_vars, self.main_vars)))


            
        # self.update_target_net_op = list(
        #     map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))

        #original
        # self.init_target_net_op = list(
        #     map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        # self.update_target_net_op = list(
        #     map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))

        # # initialize all variables
        # tf.variables_initializer(self._global_vars('')).run()
        # self._sync_optimizers()
        # self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
                            #  'main', 'target', 'lock', 'env', 'sample_transitions', #original code
                            'main', 'target1', 'target2', 'lock', 'env', 'sample_transitions',
                             'stage_shapes', 'create_actor_critic']

        state = {k: v for k, v in self.__dict__.items() if all([not subname in k for subname in excluded_subnames])}
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert(len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)

    def save(self, save_path):
        tf_util.save_variables(save_path)
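The TD3-style modification above amounts to taking the element-wise minimum over the two target critics before forming the clipped one-step return. The following is a standalone sketch of just that target/loss computation, assuming TensorFlow 1.x; the gamma, clip_return, placeholder shapes and the numbers fed in are illustrative assumptions, not values from the example above.

import tensorflow as tf

gamma, clip_return = 0.98, 50.                  # illustrative hyperparameters
r_tf  = tf.placeholder(tf.float32, (None, 1))   # reward
q1_tf = tf.placeholder(tf.float32, (None, 1))   # stands in for target1.Q_pi_tf
q2_tf = tf.placeholder(tf.float32, (None, 1))   # stands in for target2.Q_pi_tf
q_tf  = tf.placeholder(tf.float32, (None, 1))   # stands in for main.Q_tf

# Clipped double-Q target: take the minimum of the two target critics,
# then clip the bootstrapped return (here to [-clip_return, 0]).
target_Q_pi_tf = tf.minimum(q1_tf, q2_tf)
target_tf = tf.clip_by_value(r_tf + gamma * target_Q_pi_tf, -clip_return, 0.)
Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - q_tf))

with tf.Session() as sess:
    feed = {r_tf: [[-1.]], q1_tf: [[-3.]], q2_tf: [[-2.5]], q_tf: [[-4.]]}
    target, loss = sess.run([target_tf, Q_loss_tf], feed_dict=feed)
    # min(-3.0, -2.5) = -3.0, target = clip(-1 + 0.98 * (-3.0)) = -3.94
    # loss = (-3.94 - (-4.0)) ** 2 = 0.0036 (up to float precision)
    print(target, loss)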
示例#9
0
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 reuse=False,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'GHER.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """

        # # print("\n\n\n\n1--", input_dims, "\n2--", buffer_size, "\n3--", hidden,
        #         "\n4--", layers, "\n5--", network_class, "\n6--", polyak, "\n7--", batch_size,
        #          "\n8--", Q_lr, "\n9--", pi_lr, "\n10--", norm_eps, "\n11--", norm_clip,
        #          "\n12--", max_u, "\n13--", action_l2, "\n14--", clip_obs, "\n15--", scope, "\n16--", T,
        #          "\n17--", rollout_batch_size, "\n18--", subtract_goals, "\n19--", relative_goals,
        #          "\n20--", clip_pos_returns, "\n21--", clip_return,
        #          "\n22--", sample_transitions, "\n23--", gamma)
        """
         Example of parameter values in the FetchReach-v1 run:
            Input_dims (dict of ints): {'o': 10, 'u': 4, 'g': 3, 'info_is_success': 1} (o, u, g are all inputs to the network)
            Buffer_size (int): 1E6 (total number of samples in the experience pool)
            Hidden (int): 256 (number of hidden-layer neurons)
            Layers (int): 3 (three-layer neural network)
            Network_class (str): 'GHER.ActorCritic'
            Polyak (float): 0.95 (smoothing coefficient for the target-network update)
            Batch_size (int): 256 (batch size)
            Q_lr (float): 0.001 (learning rate)
            Pi_lr (float): 0.001 (learning rate)
            Norm_eps (float): 0.01 (to avoid numerical overflow)
            Norm_clip (float): 5 (norm_clip)
            Max_u (float): 1.0 (the range of the action is [-1.0, 1.0])
            Action_l2 (float): 1.0 (coefficient for the L2 penalty on actions)
            Clip_obs (float): 200 (obs is limited to (-200, +200))
            Scope (str): "ddpg" (scope name used for the TensorFlow graph)
            T (int): 50 (number of steps per interaction cycle)
            Rollout_batch_size (int): 2 (number of parallel rollouts per DDPG agent)
            Subtract_goals (function): a function that preprocesses the goal, with inputs a and b and output a - b
            Relative_goals (boolean): False (True if goals need to be processed with subtract_goals)
            Clip_pos_returns (boolean): True (whether positive returns should be clipped)
            Clip_return (float): 50 (limit the range of returns to [-clip_return, clip_return])
            Sample_transitions (function): the sampling function returned by HER; its parameters are defined in config.py
            Gamma (float): 0.98 (the discount factor used for Q-network updates)

            sample_transitions comes from the HER definition and is a key component
        """

        if self.clip_return is None:
            self.clip_return = np.inf

        # The creation of the network structure and calculation graph is done by the actor_critic.py file
        self.create_actor_critic = import_function(self.network_class)

        # Extract dimension
        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']  # 10
        self.dimg = self.input_dims['g']  # 3
        self.dimu = self.input_dims['u']  # 4
        # print("+++", input_shapes)    #  {'o': (10,), 'u': (4,), 'g': (3,), 'info_is_success': (1,)}

        # https://www.tensorflow.org/performance/performance_models
        # StagingArea provides simpler functionality and can be executed in parallel with other phases in the CPU and GPU.
        # Split the input pipeline into 3 separate parallel operations, and this is scalable to take advantage of large multi-core environments

        # Define the required staging variables. Suppose self.dimo=10, self.dimg=5, self.dimu=5;
        # then stage_shapes={'o': (None, 10), 'g': (None, 5), 'u': (None, 5)}
        # and the variables used by the target network are added as well: {'o_2': (None, 10), 'g_2': (None, 5)}
        # Prepare staging area for feeding data to the model.

        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )  # scalar reward
        self.stage_shapes = stage_shapes
        # After executing self.stage_shapes =
        # OrderedDict([('g', (None, 3)), ('o', (None, 10)), ('u', (None, 4)), ('o_2', (None, 10) ), ('g_2', (None, 3)), ('r', (None,))])
        # Including g, o, u, target used in o_2, g_2 and reward r
        # Create network.
        # Create tf placeholders based on stage_shapes, including g, o, u, o_2, g_2, r
        # self.buffer_ph_tf = [<tf.Tensor 'ddpg/Placeholder:0' shape=(?, 3) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_1:0' shape=(?, 10) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_2:0' shape=(?, 4) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_3:0' shape=(?, 10) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_4:0' shape=(?, 3) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_5:0' shape=(?,) dtype=float32>]
        with tf.variable_scope(self.scope):
            # Create a StagingArea variable
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            # Create a Tensorflow variable placeholder
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            # Correspond to the tensorflow variable and the StagingArea variable
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)
            #
            self._create_network(reuse=reuse)

        # Experience pool related operations
        # When T = 50, after execution, buffer_shapes=
        #         {'o': (51, 10), 'u': (50, 4), 'g': (50, 3), 'info_is_success': (50, 1), 'ag': (51, 3)}
        # Note that u, g, and info record the T samples of one cycle (50 entries each), but o and ag need one extra entry (T + 1 = 51) for the final state
        buffer_shapes = {
            key: (self.T if key != 'o' else self.T + 1, *input_shapes[key])
            for key, val in input_shapes.items()
        }  #
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)  #
        buffer_shapes['ag'] = (self.T + 1, self.dimg)  #
        # print("+++", buffer_shapes)

        # buffer_size is the length counted in samples
        # self.buffer_size=1E6  self.rollout_batch_size=2 buffer_size=1E6
        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

    def _random_action(self, n):
        """
            Randomly sample n actions from [-self.max_u, +self.max_u]
        """
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        """
            Preprocess obs, goal, and achieved_goal.
            If self.relative_goals=True, then goal = goal - achieved_goal
        """
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)  # Increase 1 dimension
            ag = ag.reshape(-1, self.dimg)  # Increase 1 dimension
            g = self.subtract_goals(g, ag)  # g = g - ag
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        """
            Select the action according to the self.main network, then add Gaussian noise, clip, epsilon-greedy operation, and output the processed action
        """
        # If self.relative_goals=True, the goal is preprocessed; otherwise it is only clipped
        o, g = self._preprocess_og(o, ag, g)
        # After calling the function self._create_network of this class, the self.main network and the self.target network are created, both of which are ActorCritic objects.
        policy = self.target if use_target_net else self.main  # Select an action based on self.main
        # action tensor output by the actor network
        vals = [policy.pi_tf]

        # print("+++")
        # print(vals.shape)

        # Feed the actor's output back into the critic network to obtain Q_pi_tf
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # Build the feed_dict with obs, goal and action as inputs to the actor and critic
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }
        # Execute the current policy network, output ret. ret[0] for action, ret[1] for Q value
        ret = self.sess.run(vals, feed_dict=feed)

        # action postprocessing
        # Add Gaussian noise to the action; np.random.randn samples from a standard Gaussian distribution
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)  # clip after adding the noise

        # Perform the epsilon-greedy operation, where epsilon is random_eps
        # np.random.binomial draws 0 or 1, returning 1 with probability random_eps
        # If the binomial distribution outputs 0, then u+=0 is equivalent to no operation; if the output is 1, then u = u + (random_action - u) = random_action
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
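        # e.g. with random_eps=0.3, roughly 30% of the rows of u are replaced by a
        # uniformly random action, and the remaining rows keep the noisy policy action.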
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u
        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True, verbose=False):
        """
            Episode_batch: array of batch_size x (T or T+1) x dim_key
                        'o' is of size T+1, others are of size T
             Call the store_episode function in replay_buffer to store samples for one sample period
             O_stats and g_stats update and store the mean and standard deviation of obs and goal, respectively, and update them regularly

        """

        # Episode_batch stores a sample of the cycle generated by generate_rollout in rollout.py
        # episode_batch is a dictionary, the keys include o, g, u, ag, info, and the values of the values are respectively
        # o (2, 51, 10), u (2, 50, 4), g (2, 50, 3), ag (2, 51, 3), info_is_success (2, 50, 1)
        # where the first dimension is the number of workers, and the second dimension is determined by the length of the cycle.

        self.buffer.store_episode(episode_batch, verbose=verbose)

        # Update the mean and standard deviation of o_stats and g_stats
        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch[
                'ag'][:, 1:, :]  # next observations and next achieved goals
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)  # total number of transitions in the episode batch

            # Sample HER transitions via self.sample_transitions
            # episode_batch now has shapes o (2, 51, 10), u (2, 50, 4), g (2, 50, 3), ag (2, 51, 3),
            #     info_is_success (2, 50, 1), o_2 (2, 50, 10), ag_2 (2, 50, 3)
            # num_normalizing_transitions = 100 here: 2 workers times 50 transitions per episode
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            # Preprocess the sampled transitions and use them to update o_stats and g_stats (Normalizer objects storing mean and std)
            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()
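            # Assuming the usual Normalizer behaviour (its implementation is not shown here), these stats
            # are later applied roughly as
            #   o_norm = clip((o - mean_o) / std_o, -norm_clip, norm_clip)
            # inside the ActorCritic networks, so keeping them current stabilizes training.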

    def get_current_buffer_size(self):
        """
            Returns the number of samples in the current experience pool
        """
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        """
            Q_adam and pi_adam are operators for updating actor networks and critic networks.
        """
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        """
            Return loss function and gradient
            Q_loss_tf, main.Q_pi_tf, Q_grad_tf, pi_grad_tf are all defined in the _create_network function
        """

        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf,
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        """
            Update main Actor and Critic network
             The updated op is defined in _create_network
        """
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        """
            Sampling is called by calling the sample function in replay_buffer.py , which is derived from the definition in her.py
             The returned sample consists of batch, which is used to build the feed_dict in the self.stage_batch function.
             Feed_dict will be used as input to select actions and update network parameters

             Calls to sample a batch, then preprocesses o and g. The key of the sample includes o, o_2, ag, ag_2, g
        """
        # Call sample and return transition to dictionary, key and val.shape
        # o (256, 10) u (256, 4) g (256, 3) info_is_success (256, 1) ag (256, 3) o_2 (256, 10) ag_2 (256, 3) r (256,)
        # print("In DDPG: ", self.batch_size)
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))
        # tensorboard visualization
        self.tfboard_sample_batch = batch
        self.tfboard_sample_tf = self.buffer_ph_tf

    def train(self, stage=True):
        """
            Calculate the gradient and then update
             Self.stage_batch was executed before the parameter update was executed in the train to build the feed_dict used for training. This function is called.
                     The self.sample_batch function, which in turn calls self.buffer.sample, which calls config_her in config.py, which configures the parameters of her.py functions.
             The operators in train are defined in self._create_network .

        """
        if stage:
            self.stage_batch(
            )  # Returns a feed_dict constructed using the sampling method of her.py to calculate the gradient

        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        """
            Update the target network, update_target_net_op is defined in the function _create_network
        """
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        """
            Define the calculation flow graph required to calculate Actor and Critic losses
        """
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()
        # running averages
        # Create Normalizer objects for obs and goal respectively
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        # batch_tf is an OrderedDict mapping each staged key to its tensor; after this block it looks like:
        # OrderedDict([('g', <tf.Tensor 'ddpg/ddpg/StagingArea_get:0' shape=(?, 3) dtype=float32>),
        # ('o', <tf.Tensor 'ddpg/ddpg/StagingArea_get:1' shape=(?, 10) dtype=float32>),
        # ('u', <tf.Tensor 'ddpg/ddpg/StagingArea_get:2' shape=(?, 4) dtype=float32>),
        # ('o_2', <tf.Tensor 'ddpg/ddpg/StagingArea_get:3' shape=(?, 10) dtype=float32>),
        # ('g_2', <tf.Tensor 'ddpg/ddpg/StagingArea_get:4' shape=(?, 3) dtype=float32>),
        # ('r', <tf.Tensor 'ddpg/Reshape:0' shape=(?, 1) dtype=float32>)])
        # batch_tf is used below as the input to the actor-critic networks

        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
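        # Sketch of the staging pipeline, using the same objects defined above:
        #   feed placeholders --(stage_op: StagingArea.put)--> staging area --(get)--> batch_tf --> losses
        # The keys of batch_tf follow the insertion order of self.stage_shapes, which is why
        # sample_batch returns the transitions in exactly that key order.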

        # Create the main network defined in ActorCritic.py
        # The ActorCritic constructor is fed **self.__dict__, so the DDPG hyperparameters are passed through without listing them explicitly.
        # print(self.main.__dict__)
        # The printed dict would show every hyperparameter passed through (buffer_size, hidden, layers, polyak,
        # batch_size, Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, T, gamma, ...) together with
        # the created tensors, e.g. pi_tf with shape (?, 4) and Q_pi_tf / Q_tf with shape (?, 1) for FetchReach-v1.

        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()

        # o_2 and g_2 are used to build the target network
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf[
                'g_2']  # the target network computes the target-Q value, so it takes the next-step o and g
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        # The critic's target-Q value is computed with the target actor and target critic networks.
        # target_Q_pi_tf uses the next state o_2 and g_2
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        # The critic loss is the mean squared error between target_tf and Q_tf; gradients do not flow through target_tf.
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
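        # In equation form (a sketch of the two expressions above):
        #   y   = clip(r + gamma * Q_target(o_2, pi_target(o_2, g_2), g_2), clip_range)
        #   L_Q = mean((stop_gradient(y) - Q_main(o, u, g))^2)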

        # The actor loss is the negative of the critic's value for the actor's action in the main network.
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        # Add an L2 regularizer on the (normalized) actions
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
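        # Equivalent sketch of the actor objective built above:
        #   L_pi = -mean(Q_main(o, pi_main(o, g), g)) + action_l2 * mean((pi_main(o, g) / max_u)^2)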

        # Compute the gradients
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(
            Q_grads_tf,
            self._vars('main/Q'))  # pair each gradient with its variable
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars(
            'main/pi'
        )  # all trainable parameters of the main Critic and Actor networks
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(  # target initialization: copy the main network parameters into the target network
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(  # target update: Polyak-average the target and main parameters with coefficient polyak
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))
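        # Sketch of what one update_target_net() call does for every target/main parameter pair:
        #   theta_target <- polyak * theta_target + (1 - polyak) * theta_main
        # e.g. with polyak = 0.95 the target network moves 5% of the way towards the main network per call,
        # while init_target_net_op simply copies theta_main into theta_target once at start-up.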

        # # Tensorboard visualization
        # tf.summary.scalar("Q_target-Q-mean", tf.reduce_mean(target_tf))
        # tf.summary.histogram("Q_target-Q", target_tf)
        # tf.summary.scalar("Q_Td-error-mean", tf.reduce_mean(target_tf - self.main.Q_tf))
        # tf.summary.histogram("Q_Td-error", target_tf - self.main.Q_tf)
        # tf.summary.scalar("Q_reward-mean", tf.reduce_mean(batch_tf['r']))
        # tf.summary.histogram("Q_reward", batch_tf['r'])
        # tf.summary.scalar("Q_loss_tf", self.Q_loss_tf)
        # tf.summary.scalar("pi_loss_tf", self.pi_loss_tf)
        # self.merged = tf.summary.merge_all()

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def tfboard_func(self, summary_writer, step):
        """
            Tensorboard visualization
        """
        self.sess.run(self.stage_op,
                      feed_dict=dict(
                          zip(self.tfboard_sample_tf,
                              self.tfboard_sample_batch)))
        summary = self.sess.run(self.merged)
        summary_writer.add_summary(summary, global_step=step)

        print("S" + str(step), end=",")

    def __getstate__(self):
        """
            Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)

    # -----------------------------------------
    def updata_loss_all(self, verbose=False):
        assert self.buffer.current_size > 0
        idxes = np.arange(self.buffer.current_size)
        print("--------------------------------------")
        print("Updata All loss start...")
        self.buffer.update_rnnLoss(idxes, verbose=verbose)
        print("Updata All loss end ...")
Example #10
0
class Algorithm(object):
    @store_args
    def __init__(self,
                 buffer,
                 input_dims,
                 hidden,
                 layers,
                 polyak,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 gamma,
                 vloss_type='normal',
                 priority=False,
                 reuse=False,
                 **kwargs):
        """
        buffer (object): buffer to save transitions
        input_dims (dict of ints): dimensions for the observation (o), the goal (g), 
            and the actions (u)
        hidden (int): number of units in the hidden layers
        layers (int): number of hidden layers
        polyak (float): coefficient for Polyak-averaging of the target network
        Q_lr (float): learning rate for the Q (critic) network
        pi_lr (float): learning rate for the pi (actor) network
        norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
        norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
        max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
        action_l2 (float): coefficient for L2 penalty on the actions
        clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
        scope (str): the scope used for the TensorFlow graph
        subtract_goals (function): function that subtracts goals from each other
        relative_goals (boolean): whether or not relative goals should be fed into the network
        clip_pos_returns (boolean): whether or not positive returns should be clipped
        clip_return (float): clip returns to be in [-clip_return, clip_return]
        gamma (float): gamma used for Q learning updates
        vloss_type (str): value loss type, 'normal', 'tf_gamma', 'target'
        priority(boolean): use priority or not
        reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf
        self.dimo, self.dimg, self.dimu = self.input_dims[
            'o'], self.input_dims['g'], self.input_dims['u']
        self.stage_shapes = self.get_stage_shapes()
        self.init_target_net_op = None
        self.update_target_net_op = None

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)
            self._create_network(reuse=reuse)

        logger.log('value loss type: {}'.format(self.vloss_type))

    def get_stage_shapes(self):
        # Prepare the staging area shapes for feeding data to the model (also covers the extra keys used by HER)
        input_shapes = dims_to_shapes(self.input_dims)
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        if self.vloss_type == 'tf_gamma':
            stage_shapes['gamma'] = (None, )
        if self.priority:
            stage_shapes['w'] = (None, )
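        # Hypothetical example: with input_dims = {'g': 3, 'o': 10, 'u': 4}, vloss_type='normal'
        # and priority=False, this would return roughly
        #   OrderedDict([('g', (None, 3)), ('o', (None, 10)), ('u', (None, 4)),
        #                ('o_2', (None, 10)), ('g_2', (None, 3)), ('r', (None,))])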
        return stage_shapes

    def _create_normalizer(self, reuse):
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('u_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.u_stats = Normalizer(self.dimu,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

    def _get_batch_tf(self):
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
        if self.priority:
            batch_tf['w'] = tf.reshape(batch_tf['w'], [-1, 1])
        return batch_tf

    def _create_target_main(self, AC_class, reuse, batch_tf):
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = AC_class(batch_tf, self.dimo, self.dimg, self.dimu,
                                 self.max_u, self.o_stats, self.g_stats,
                                 self.hidden, self.layers, self.sess)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = AC_class(target_batch_tf, self.dimo, self.dimg,
                                   self.dimu, self.max_u, self.o_stats,
                                   self.g_stats, self.hidden, self.layers,
                                   self.sess)
            vs.reuse_variables()
        assert len(get_var(self.scope + "/main")) == len(
            get_var(self.scope + '/target'))

    def _clip_target(self, batch_tf, clip_range, target_V_tf):
        if self.vloss_type == 'tf_gamma':
            target_tf = tf.clip_by_value(
                batch_tf['r'] + batch_tf['gamma'] * target_V_tf, *clip_range)
        elif self.vloss_type == 'target':
            target_tf = tf.clip_by_value(batch_tf['r'], *clip_range)
        else:
            target_tf = tf.clip_by_value(
                batch_tf['r'] + self.gamma * target_V_tf, *clip_range)
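        # Summary of the three target variants handled above:
        #   'tf_gamma': y = clip(r + gamma_t * V'(s'), clip_range)   (per-transition gamma staged with the batch)
        #   'target'  : y = clip(r, clip_range)                      (r itself is used as the target)
        #   default   : y = clip(r + gamma * V'(s'), clip_range)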
        return target_tf

    def _create_network(self, reuse=False):
        raise NotImplementedError

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, g, ag=None):
        if self.relative_goals and ag is not None:
            g_shape = g.shape
            g, ag = g.reshape(-1, self.dimg), ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def step(self, obs):  # act without noise
        actions = self.get_actions(obs['observation'], obs['achieved_goal'],
                                   obs['desired_goal'])
        return actions, None, None, None

    def simple_get_action(self, o, g, use_target_net=False):
        o, g = self._preprocess_og(o=o, g=g)
        policy = self.target if use_target_net else self.main  # in n-step self.target performs better
        action = self.sess.run(policy.pi_tf,
                               feed_dict={
                                   policy.o_tf: o.reshape(-1, self.dimo),
                                   policy.g_tf: g.reshape(-1, self.dimg)
                               })
        return action

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o=o, g=g, ag=ag)
        u = self.simple_get_action(o, g, use_target_net)
        if compute_Q:
            Q_pi = self.get_Q_fun(o, g)

        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u = np.clip(u + noise, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]

        if compute_Q:
            return [u, Q_pi]
        else:
            return u

    def get_Q_fun(self, o, g, u=None, Q_pi=True):
        o, g = self._preprocess_og(o, g)
        policy = self.target
        if Q_pi or (u is None):
            return policy.get_Q_pi(o, g)
        else:
            return policy.get_Q(o, g, u)

    def store_episode(self, episode_batch, update_stats=True):
        """episode_batch: array of batch_size x (T or T+1) x dim_key, 'o' is of size T+1, others are of size T"""
        self.buffer.store_episode(episode_batch)
        if update_stats:  # the episode batch does not have an 'o_2' key yet
            os, gs, ags = episode_batch['o'].copy(), episode_batch['g'].copy(
            ), episode_batch['ag'].copy()
            os, gs = self._preprocess_og(o=os, g=gs, ag=ags)
            # update normalizer online
            self.o_stats.update_all(os)
            self.g_stats.update_all(gs)

    def _sync_optimizers(self):
        raise NotImplementedError

    def _grads(self):  # Avoid feed_dict here for performance!
        raise NotImplementedError

    def _update(self, Q_grad, pi_grad):
        raise NotImplementedError

    def stage_batch(self, batch=None):
        if batch is None:
            if self.priority:
                transitions, idxes = self.buffer.sample()
            else:
                transitions = self.buffer.sample()
            o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
            ag, ag_2 = transitions['ag'], transitions['ag_2']
            transitions['o'], transitions['g'] = self._preprocess_og(o=o,
                                                                     g=g,
                                                                     ag=ag)
            transitions['o_2'], transitions['g_2'] = self._preprocess_og(
                o=o_2, g=g, ag=ag_2)

            batch = [transitions[key] for key in self.stage_shapes.keys()]
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))
        if self.priority:
            return idxes

    def train(self, stage=True):
        if stage:
            idxes = self.stage_batch()
        critic_loss, actor_loss, Value_grad, pi_grad, abs_td_error = self._grads(
        )
        self._update(Value_grad, pi_grad)
        if self.priority:
            self.buffer.update_priorities(idxes, abs_td_error)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def logs_stats(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
        logs += [('stats_u/mean', np.mean(self.sess.run([self.u_stats.mean])))]
        logs += [('stats_u/std', np.mean(self.sess.run([self.u_stats.std])))]
        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def save(self, save_path):
        tf_util.save_variables(save_path)
Example #11
0
class DDPG(object):
    @store_args
    def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
                 Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
                 rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
                 sample_transitions, gamma, reuse=False, env=None, to_goal=None, nearby_action_penalty=False, nearby_penalty_weight=0,
                 sample_expert=False, expert_batch_size=0., bc_loss=0., anneal_bc=0., terminate_bootstrapping=False, mask_q = False,
                 two_qs=False, anneal_discriminator=False, **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf
        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None,)
        if two_qs:
            stage_shapes['r2'] = (None,)
            stage_shapes['w_q2'] = (None, )
        stage_shapes['successes'] = (None,)
        if nearby_action_penalty:
            stage_shapes['far_from_goal'] = (None, )
        if sample_expert:
            stage_shapes['is_demo'] = (None, )
            stage_shapes['annealing_factor'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        # print(self.stage_shapes.keys())
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)
            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
                         for key, val in input_shapes.items()}
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T+1, self.dimg)
        buffer_shapes['successes'] = (self.T,)


        buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)
        self.expert_buffer = None

        self.all_variables = self._global_vars('')

        if to_goal is None:
            print("to goal is none!")
            self.to_goal = (0, 2)
        else:
            self.to_goal = to_goal
        self.to_goal_func = (lambda x: x[self.to_goal[0] : self.to_goal[1]]) if len(self.to_goal) == 2 else (lambda x: x[np.array(self.to_goal)])
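        # Illustrative sketch (hypothetical observation layout): with to_goal = (0, 2),
        #   to_goal_func(o) returns o[0:2], i.e. the first two observation entries are treated as the achieved goal;
        # a to_goal with more than two entries is instead used as an index array into o.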
        self.nearby_action_penalty = nearby_action_penalty
        self.nearby_penalty_weight = nearby_penalty_weight


    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf: o.reshape(-1, self.dimo),
            policy.g_tf: g.reshape(-1, self.dimg),
            policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32),
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(*u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def get_action(self, o, noise=0.1):
        return self.get_actions([o], self.to_goal_func(o), self.env.current_goal, noise_eps=noise), None

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
            transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        if self.two_qs:
            self.Q2_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        if not self.two_qs:
            critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
                self.Q_loss_tf,
                self.main.Q_pi_tf,
                self.Q_grad_tf,
                self.pi_grad_tf
            ])
            return critic_loss, actor_loss, Q_grad, pi_grad
        else:
            critic_loss, critic_loss2, actor_loss, Q_grad, Q2_grad, pi_grad = self.sess.run([
                self.Q_loss_tf,
                self.Q2_loss_tf,
                self.main.Q_pi_tf,
                self.Q_grad_tf,
                self.Q2_grad_tf,
                self.pi_grad_tf
            ])
            return critic_loss, critic_loss2, actor_loss, Q_grad, Q2_grad, pi_grad

    def _update(self, Q_grad, pi_grad, Q2_grad=None):
        self.Q_adam.update(Q_grad, self.Q_lr)
        if self.two_qs:
            self.Q2_adam.update(Q2_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch_helper(self, buffer, batch_size, expert=False, annealing_factor=1., w_q2=1.):
        transitions = buffer.sample(batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)
        transitions['is_demo'] = int(expert) * np.ones_like(transitions['r']).astype(np.float32)
        transitions['annealing_factor'] = annealing_factor * np.ones_like(transitions['r']).astype(np.float32)
        if self.two_qs:
            transitions['w_q2'] = w_q2 * np.ones_like(transitions['r']).astype(np.float32)
        if self.anneal_discriminator:
            transitions['r'] = transitions['r'] + w_q2 * transitions['r2']
        transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
        return transitions_batch

    def sample_batch(self, annealing_factor=1., w_q2=1.):
        transitions_batch = None
        if self.batch_size > 0:
            transitions_batch = self.sample_batch_helper(self.buffer, self.batch_size, w_q2=w_q2)
        if self.sample_expert and self.expert_buffer is not None:
            expert_batch = self.sample_batch_helper(self.expert_buffer, self.expert_batch_size, expert=True, annealing_factor=annealing_factor, w_q2=w_q2)
            transitions_batch = expert_batch if transitions_batch is None else\
                [np.concatenate([normal, expert]) for (normal, expert) in zip(transitions_batch, expert_batch)]
        return transitions_batch

    def stage_batch(self, batch=None, annealing_factor=1., w_q2=1.):
        if batch is None:
            batch = self.sample_batch(annealing_factor=annealing_factor, w_q2=w_q2)
            # return goals that are trained on

        assert len(self.buffer_ph_tf) == len(batch)

        self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))
        return batch

    def train(self, stage=True, annealing_factor=1., q_annealing=1.):
        batch = None  # only populated when a fresh batch is staged below
        if stage:
            batch = self.stage_batch(annealing_factor=annealing_factor, w_q2=q_annealing)
        if not self.two_qs:
            critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
            self._update(Q_grad, pi_grad)
        else:
            critic_loss, critic_loss2, actor_loss, Q_grad, Q2_grad, pi_grad = self._grads()
            self._update(Q_grad, pi_grad, Q2_grad=Q2_grad)

        return critic_loss, actor_loss, batch


    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        # logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats', reuse=reuse) as vs:
            # if reuse:
            #     vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
        with tf.variable_scope('g_stats', reuse=reuse) as vs:
            # if reuse:
            #     vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)
        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([(key, batch[i])
                                for i, key in enumerate(self.stage_shapes.keys())])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
        batch_tf['successes'] = tf.reshape(batch_tf['successes'], [-1, 1])
        # networks
        with tf.variable_scope('main', reuse=reuse) as vs:
            # if reuse:
            #     vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target', reuse=reuse) as vs:
            # if reuse:
            #     vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(
                target_batch_tf, net_type='target',  **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))
        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        if self.two_qs:
            target_Q2_pi_tf = self.target.Q2_pi_tf
        # clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
        clip_range = (-np.inf, self.clip_return)
        # print(clip_range)
        if self.terminate_bootstrapping:
            target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * (1 - batch_tf['successes']) * target_Q_pi_tf, *clip_range)
            if self.two_qs:
                target2_tf = tf.clip_by_value(batch_tf['r2'] + self.gamma * (1 - batch_tf['successes']) * target_Q2_pi_tf, *clip_range)
        else:
            target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
            if self.two_qs:
                target2_tf = tf.clip_by_value(batch_tf['r2'] + self.gamma * target_Q2_pi_tf, *clip_range)
        if self.nearby_action_penalty:
            target_tf -= tf.reshape(batch_tf['far_from_goal'] * self.nearby_penalty_weight * tf.norm(self.main.pi_tf - batch_tf['u'], axis=-1), (-1, 1))
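        # Sketch of the target built above when terminate_bootstrapping is enabled:
        #   y = clip(r + gamma * (1 - success) * Q'(o_2, pi'(o_2)), clip_range)
        # so bootstrapping is cut off once the goal has been reached; with nearby_action_penalty the target is
        # further reduced by nearby_penalty_weight * ||pi(o) - u|| on transitions flagged as far_from_goal.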

        self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
        if self.two_qs:
            self.Q2_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target2_tf) - self.main.Q2_tf))

        if self.mask_q:
            self.pi_loss_tf = 0
        else:
            if self.two_qs:
                self.pi_loss_tf = -tf.reduce_mean((1 - batch_tf['w_q2'])[:, None] * self.main.Q_pi_tf + batch_tf['w_q2'][:, None] * self.main.Q2_pi_tf)
            else:
                self.pi_loss_tf =  -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf/ self.max_u))
        if self.sample_expert:
            self.pi_loss_tf += (1 - self.anneal_bc * tf.to_float(tf.greater_equal(self.target.Q_pi_tf, self.target.Q_tf))) * \
                self.bc_loss * tf.reduce_mean(batch_tf['is_demo'] * batch_tf['annealing_factor'] *
                tf.reduce_sum(tf.square(self.main.pi_tf - batch_tf['u']), axis=-1 ))
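        # Sketch of the demonstration (behaviour-cloning) term added above:
        #   bc ~ bc_loss * mean(is_demo * annealing_factor * ||pi(o) - u_demo||^2)
        # gated by (1 - anneal_bc * 1[Q'_pi >= Q'_u]): with anneal_bc = 1 the BC penalty is switched off
        # wherever the target critic already rates the policy's action at least as highly as the demonstration.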

        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))

        if self.two_qs:
            Q2_grads_tf = tf.gradients(self.Q2_loss_tf, self._vars('main/2Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))

        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)

        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        if self.two_qs:
            self.Q2_grads_vars_tf = zip(Q2_grads_tf, self._vars('main/2Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))

        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
        if self.two_qs:
            self.Q2_grad_tf = flatten_grads(grads=Q2_grads_tf, var_list=self._vars('main/2Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        if self.two_qs:
            self.Q2_adam = MpiAdam(self._vars('main/2Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi') + (self._vars('main/2Q') if self.two_qs else [])
        self.target_vars = self._vars('target/Q') + self._vars('target/pi') + (self._vars('target/2Q') if self.two_qs else [])
        self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))
        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
        self.stage_batch()
        # logs += [('action_diff', np.mean(self.sess.run([tf.norm(self.main.u_tf - self.main.pi_tf, axis=-1)])))]


        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    # to be compatible with rollout collection in rllab
    def reset(self):
        pass

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
                             'main', 'target', 'lock', 'env', 'sample_transitions',
                             'stage_shapes', 'create_actor_critic', 'all_variables', 'to_goal_func']

        state = {k: v for k, v in self.__dict__.items() if all([not subname in k for subname in excluded_subnames])}
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run([x for x in self.all_variables if 'buffer' not in x.name])
        # print("global variables", tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        # print("in get state")
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        # print(vars)
        # print(state['tf'])
        assert(len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
Example #12
0
class PGGD(object):

    DIMO = 0

    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 reuse=False,
                 **kwargs):
        """Implementation of PGGD that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per PGGD agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        # ------------------
        # To access information of environment name and stuff
        self.kwargs = kwargs
        # ------------------

        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # ----------------------
        input_shapes['o'] = (None, )
        # ----------------------

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )

        # ----------------------
        stage_shapes['G'] = (None, )
        # ----------------------

        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T, *input_shapes[key]) if key != 'o' else
            (self.T + 1, PGGD.DIMO)
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)
        # -------------------
        buffer_shapes['G'] = (self.T, )
        buffer_shapes['sigma'] = (self.T, self.dimu)
        self.weight_path = None
        # -------------------
        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)

        return o, g

    # -------------------------------
    # If observation has more dimensions than what the policy takes in
    # then just truncate it.
    def get_actions(self, o, ag, g, exploit=False):
        # if len(o.shape) == 1:
        #     o = o[:self.dimo]
        #     g = g[:self.dimg]
        #     ag = ag[:self.dimg]
        # else:
        #     o = o[:,:self.dimo]
        #     g = g[:,:self.dimg]
        #     ag = ag[:,:self.dimg]
        o, g = self._preprocess_og(o, ag, g)
        policy = self.main
        # values to compute
        if exploit:
            vals = [policy.da_tf]
        else:
            vals = [policy.a_tf]

        vals += [policy.raw_tf, policy.sigma_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u, raw, sigma = ret
        if u.shape[0] == 1:
            u = u[0]
            raw = raw[0]
            sigma = sigma[0]
        u = u.copy()
        raw = raw.copy()
        sigma = sigma.copy()
        return u, raw, sigma
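    # Illustrative call (a sketch; `policy`, `o`, `ag`, `g` are assumed caller-side
    # names, shaped per input_dims and the feed above):
    #   u, raw, sigma = policy.get_actions(o, ag, g, exploit=False)
    # `u` is the action actually taken, `raw` the underlying network output and
    # `sigma` the exploration stddev reported by the network; with exploit=True
    # the deterministic output da_tf is evaluated instead of the sampled a_tf.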

    # -------------------------------

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats
            if 'Variation' in self.kwargs['info']['env_name']:
                o = transitions['o'][:, 1:]
                # o = np.concatenate([transitions['o'][:,:ENV_FEATURES],
                #                     transitions['o'][:,ENV_FEATURES+1:]], axis=1)
            else:
                o = transitions['o']

            self.o_stats.update(o)
            self.G_stats.update(transitions['G'])
            self.sigma_stats.update(transitions['sigma'])
            # self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            # self.g_stats.recompute_stats()
            self.G_stats.recompute_stats()
            self.sigma_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        pi_loss, pi_grad, mu = self.sess.run(
            [self.pi_loss_tf, self.pi_grad_tf, self.main.mu_tf])
        # print(np.mean(mu), np.mean(pi_grad), np.mean(pi_loss))
        return pi_loss, pi_grad

    def _update(self, pi_grad):
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        # print(transitions['G'])
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        pi_loss, pi_grad = self._grads()
        self._update(pi_grad)
        # print(np.mean(pi_grad))
        return pi_loss
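    # Training-loop sketch (illustrative; `policy` and `n_batches` are assumed
    # caller-side names, not attributes of this class):
    #   for _ in range(n_batches):
    #       pi_loss = policy.train()  # stage a batch, compute grads, apply MpiAdam update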

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a PGGD agent with action space %d x %s..." %
                    (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            o_stats_dim = self.dimo
            if 'Variation' in self.kwargs['info']['env_name']:
                print("Found Variation in env name")
                o_stats_dim -= 1
            self.o_stats = Normalizer(o_stats_dim,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        # --------------
        with tf.variable_scope('G_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.G_stats = Normalizer(1,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('sigma_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.sigma_stats = Normalizer(self.dimu,
                                          self.norm_eps,
                                          self.norm_clip,
                                          sess=self.sess)
        # --------------
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
        # ------------
        batch_tf['G'] = tf.reshape(batch_tf['G'], [-1])
        # ------------

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # ---------------------------
        # loss functions
        log_prob = tf.reduce_sum(tf.log(
            tf.clip_by_value(self.main.a_prob_tf, 1e-10, 1.0)),
                                 axis=1)
        neg_weighted_log_prob = -tf.multiply(batch_tf['G'], log_prob)
        self.pi_loss_tf = tf.reduce_mean(neg_weighted_log_prob)
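        # This is a return-weighted REINFORCE objective,
        #   L(theta) = -E[ G * sum_i log pi(a_i | s, g) ],
        # where G is the return staged with each transition and a_prob_tf is
        # clipped away from zero before the log for numerical stability.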

        # https://github.com/tensorflow/tensorflow/issues/783
        def replace_none_with_zero(grads, var_list):
            return [
                grad if grad is not None else tf.zeros_like(var)
                for var, grad in zip(var_list, grads)
            ]

        pi_grads_tf = replace_none_with_zero(
            tf.gradients(self.pi_loss_tf, self._vars('main/pi')),
            self._vars('main/pi'))
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))
        # ---------------------------

        # optimizers
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        # self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        # self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats') + self._global_vars('G_stats') + self._global_vars(
                'sigma_stats')
        # self.init_target_net_op = list(
        #     map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        # self.update_target_net_op = list(
        #     map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        # self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
        logs += [('stats_G/mean', np.mean(self.sess.run([self.G_stats.mean])))]
        logs += [('stats_G/std', np.mean(self.sess.run([self.G_stats.std])))]
        logs += [('stats_stddev/mean',
                  np.mean(self.sess.run([self.sigma_stats.mean])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def set_sample_transitions(self, fn):
        self.sample_transitions = fn
        self.buffer.sample_transitions = fn

    def set_obs_size(self, dims):
        self.input_dims = dims
        self.dimo = dims['o']
        self.dimg = dims['g']
        self.dimu = dims['u']

    def save_weights(self, path):
        self.main.save_weights(self.sess, path)
        self.weight_path = path

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        self.weight_path = state['weight_path']
        # Hard override...
        # This is due to the fact that the directory that the weights are saved to
        # might not be the same when it is loaded again
        # TODO: Delete this!!!!
        self.weight_path = "/Users/matt/RL/Results/5-3blocks-GPGGD-3-256/weights"
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [
            tf.no_op() if 'o_stats' in var.name else tf.assign(var, val)
            for var, val in zip(vars, state["tf"])
        ]
        self.sess.run(node)
        if self.weight_path is not None:
            print("Reading weights for sure this time!")
            print(self.weight_path)
            print(tf.train.latest_checkpoint(self.weight_path))
            self.main.load_weights(self.sess, self.weight_path)
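# Usage sketch for the (de)serialization above (illustrative; `policy` and the
# 'policy.pkl' path are assumptions, not part of the class):
#
#     import pickle
#     with open('policy.pkl', 'wb') as f:
#         pickle.dump(policy, f)      # __getstate__ strips TF handles, stores variable values
#     with open('policy.pkl', 'rb') as f:
#         restored = pickle.load(f)   # __setstate__ rebuilds the graph and re-assigns them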
Example #13
0
File: ddpg.py Project: s-bl/cwyc
class DDPG(ParallelModule):
    def __init__(self,
                 env_spec,
                 task_spec,
                 buffer_size,
                 network_params,
                 normalizer_params,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 random_eps,
                 noise_eps,
                 train_steps,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 replay_strategy,
                 replay_k,
                 noise_type,
                 share_experience,
                 noise_adaptation,
                 reuse=False):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
            Added functionality to use demonstrations for training to overcome the exploration problem.

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
            bc_loss: whether or not the behavior cloning loss should be used as an auxiliary loss
            q_filter: whether or not a filter on the Q value update should be used when training with demonstrations
            num_demo: number of episodes to be used in the demonstration buffer
            demo_batch_size: number of samples to be used from the demonstration buffer, per MPI thread
            prm_loss_weight: weight corresponding to the primary loss
            aux_loss_weight: weight corresponding to the auxiliary (cloning) loss
        """
        super().__init__(scope)
        self.replay_k = replay_k
        self.replay_strategy = replay_strategy
        self.clip_pos_returns = clip_pos_returns
        self.relative_goals = relative_goals
        self.train_steps = train_steps
        self.noise_eps = noise_eps
        self.random_eps = random_eps
        self.clip_obs = clip_obs
        self.action_l2 = action_l2
        self.max_u = max_u
        self.pi_lr = pi_lr
        self.Q_lr = Q_lr
        self.batch_size = batch_size
        self.normalizer_params = normalizer_params
        self.polyak = polyak
        self.buffer_size = buffer_size
        self._env_spec = env_spec
        self._T = self._env_spec['T']
        self._task_spec = task_spec
        self.network_params = network_params
        self._share_experience = share_experience
        self._noise_adaptation = noise_adaptation

        self._task_spec = deepcopy(task_spec)
        self._task_spec['buffer_size'] = 0
        self._task = Task(**self._task_spec)

        self._gamma = 1. - 1. / self._T
        self.clip_return = (1. / (1. - self._gamma)) if clip_return else np.inf
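        # With gamma = 1 - 1/T and per-step rewards bounded by 1 (as in the
        # sparse-reward HER setting), the discounted return is bounded by
        # 1 / (1 - gamma) = T, which is the value used for clipping above.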

        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(network_params['net_type'])

        self.input_dims = dict(
            o=self._env_spec['o_dim'],
            a=self._env_spec['a_dim'],
            g=self._task_spec['g_dim'],
        )

        input_shapes = dims_to_shapes(self.input_dims)

        self.dimo = self._env_spec['o_dim']
        self.dimg = self._task_spec['g_dim']
        self.dima = self._env_spec['a_dim']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_next'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        self._action_noise, self._parameter_noise = get_noise_from_string(
            self._env_spec, noise_type)

        # Create network.
        with tf.variable_scope(self._scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        buffer_shapes = dict()
        buffer_shapes['o'] = (self.dimo, )
        buffer_shapes['o_next'] = buffer_shapes['o']
        buffer_shapes['g'] = (self.dimg, )
        buffer_shapes['ag'] = (self.dimg, )
        buffer_shapes['ag_next'] = (self.dimg, )
        buffer_shapes['a'] = (self.dima, )

        self.sample_transitions = make_sample_her_transitions(
            self.replay_strategy, self.replay_k,
            self._task.reward_done_success)

        self._buffer = ReplayBuffer(buffer_shapes, self.buffer_size, self._T,
                                    self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dima))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = simple_goal_subtract(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def step(self, obs):
        actions = self.get_actions(obs['observation'], obs['achieved_goal'],
                                   obs['desired_goal'])
        return actions, None, None, None

    def pre_rollout(self):
        if self._parameter_noise is not None:
            self.adapt_param_noise()
            self.sess.run(self.perturbe_op,
                          feed_dict={
                              self.param_noise_stddev:
                              self._parameter_noise.current_stddev
                          })

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=None,
                    random_eps=None,
                    use_target_net=False,
                    compute_Q=False,
                    success_rate=None,
                    mode=TRAIN):
        g = self._task.mg_fn(g)
        ag = self._task.mg_fn(ag)
        o, g = self._preprocess_og(o, ag, g)
        if mode == EVAL:
            policy = self.target if use_target_net else self.main
        else:
            if self._parameter_noise is not None:
                policy = self.perturbed
            else:
                policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.a_tf:
            np.zeros((o.size // self.dimo, self.dima), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        if mode == EVAL:
            u = np.clip(u, -self.max_u, self.max_u)
        else:
            if self._action_noise is not None:
                noise = self._action_noise()
                assert u.shape[0] == 1
                u = u[0]
                u += noise
                u = np.clip(u, -self.max_u, self.max_u)
            if self._parameter_noise is None and self._action_noise is None:
                if noise_eps is None: noise_eps = self.noise_eps
                if random_eps is None: random_eps = self.random_eps
                if self._noise_adaptation:
                    noise_eps *= 1 - success_rate
                    random_eps *= 1 - success_rate
                noise = noise_eps * self.max_u * np.random.randn(
                    *u.shape)  # gaussian noise
                u += noise
                u = np.clip(u, -self.max_u, self.max_u)
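                # Epsilon-greedy mixing: with probability random_eps a row of u
                # is replaced by a uniformly random action (the binomial mask
                # selects the rows, the added difference performs the swap).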
                u += np.random.binomial(1, random_eps, u.shape[0]).reshape(
                    -1, 1) * (self._random_action(u.shape[0]) - u
                              )  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_transitions(self, episode, info, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        new_episode = episode
        if not self._share_experience:
            new_episode = {
                key: [
                    np.asarray(value[i][np.squeeze(episode['tasks'][i]) ==
                                        self._task_spec['id']])
                    for i in range(len(value)) if np.any(
                        np.squeeze(episode['tasks'][i]) ==
                        self._task_spec['id'])
                ]
                for key, value in episode.items()
            }

            batch_sizes = [len(value) for value in new_episode.values()]
            assert np.all(np.array(batch_sizes) == batch_sizes[0])
            batch_size = batch_sizes[0]
            if batch_size == 0: return

        new_batch = deepcopy(
            dict(
                o=new_episode['o'],
                o_next=new_episode['o_next'],
                a=new_episode['a'],
                ag=[self._task.mg_fn(ag) for ag in new_episode['ag']],
                ag_next=[
                    self._task.mg_fn(ag_next)
                    for ag_next in new_episode['ag_next']
                ],
                g=[self._task.mg_fn(g) for g in new_episode['g']],
                g_next=[
                    self._task.mg_fn(g_next)
                    for g_next in new_episode['g_next']
                ],
                r=new_episode['r'],
            ))

        self._buffer.store_episode(new_batch)

        if update_stats:
            num_normalizing_transitions = transitions_in_episode_batch(
                new_batch)
            new_batch['ep_T'] = np.asarray(
                [ep.shape[0] for ep in list(new_batch.values())[0]])
            new_batch = {
                key: np.asarray(value)
                for key, value in new_batch.items()
            }
            transitions = self.sample_transitions(new_batch,
                                                  num_normalizing_transitions)

            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self._buffer.get_current_size()

    def sample_batch(self):
        transitions = self._buffer.sample(self.batch_size)

        o, o_next, g = transitions['o'], transitions['o_next'], transitions[
            'g']
        ag, ag_next = transitions['ag'], transitions['ag_next']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_next'], transitions['g_next'] = self._preprocess_og(
            o_next, ag_next, g)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        Q_loss, pi_loss, Q, q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf, self.main.Q_pi_tf, self.target.Q_pi_tf,
            self.Q_grad_tf, self.pi_grad_tf
        ])
        return Q_loss, pi_loss, Q, q_grad, pi_grad

    def _update(self, q_grad, pi_grad):
        self.Q_adam.update(q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def train(self, stage=True):
        if self._buffer.get_current_size() == 0: return {}
        Q_losses = []
        pi_losses = []
        Qs = []
        for i in range(self.train_steps):
            if stage:
                self.stage_batch()
            Q_loss, pi_loss, Q, q_grads, pi_grads = self._grads()
            self._update(q_grads, pi_grads)
            # Q_loss, pi_loss, Q, q_grads, pi_grads, *_ = self.sess.run([self.Q_loss_tf, self.main.Q_pi_tf,
            #                                                             self.target.Q_pi_tf,
            #                                                             self.Q_grads_tf, self.pi_grads_tf,
            #                                                             self.Q_train_op, self.pi_train_op])
            Q_losses.append(Q_loss)
            pi_losses.append(np.mean(pi_loss))
            Qs.append(Q)
        self.update_target_net()
        return {
            'Q_loss': ('scalar', np.mean(Q_losses)),
            'pi_loss': ('scalar', np.mean(pi_losses)),
            'Q': ('hist', np.hstack(Qs)),
            'q_grads':
            ('hist', np.hstack([q_grad.reshape(-1) for q_grad in q_grads])),
            'pi_grads':
            ('hist', np.hstack([pi_grad.reshape(-1) for pi_grad in pi_grads])),
        }

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self._buffer.clear_buffer()

    def _create_network(self, reuse=False):
        self.sess = tf.get_default_session() or tf.InteractiveSession(
            config=tf.ConfigProto(inter_op_parallelism_threads=1,
                                  intra_op_parallelism_threads=1))

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(
                self.dimo,
                self.normalizer_params['eps'],
                self.normalizer_params['default_clip_range'],
                sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(
                self.dimg,
                self.normalizer_params['eps'],
                self.normalizer_params['default_clip_range'],
                sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        network_params = deepcopy(self.network_params)
        del network_params['net_type']
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **network_params,
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_next']
            target_batch_tf['g'] = batch_tf['g_next']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **network_params,
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self.vars("main")) == len(self.vars("target"))
        if self._parameter_noise is not None:
            with tf.variable_scope('perturbed') as vs:
                if reuse:
                    vs.reuse_variables()
                self.perturbed = self.create_actor_critic(batch_tf,
                                                          net_type='perturbed',
                                                          **network_params,
                                                          **self.__dict__)
                vs.reuse_variables()
            assert len(self.vars("main")) == len(self.vars("perturbed"))
            with tf.variable_scope('adaptive_param_noise') as vs:
                if reuse:
                    vs.reuse_variables()
                self.adaptive_param_noise = self.create_actor_critic(
                    batch_tf,
                    net_type='adaptive_param_noise',
                    **network_params,
                    **self.__dict__)
                vs.reuse_variables()
            assert len(self.vars("main")) == len(
                self.vars("adaptive_param_noise"))

            self.adaptive_param_noise_distance = tf.sqrt(
                tf.reduce_mean(
                    tf.square(self.main.pi_tf -
                              self.adaptive_param_noise.pi_tf)))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self._gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
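        # Standard DDPG critic loss: the 1-step TD target r + gamma * Q'(s', pi'(s'))
        # comes from the target networks and is clipped to the valid return range;
        # stop_gradient keeps the critic update from differentiating the target.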

        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))

        self.Q_grads_tf = tf.gradients(self.Q_loss_tf, self.vars('main/Q'))
        self.pi_grads_tf = tf.gradients(self.pi_loss_tf, self.vars('main/pi'))
        assert len(self.vars('main/Q')) == len(self.Q_grads_tf)
        assert len(self.vars('main/pi')) == len(self.pi_grads_tf)
        self.Q_grads_vars_tf = zip(self.Q_grads_tf, self.vars('main/Q'))
        self.pi_grads_vars_tf = zip(self.pi_grads_tf, self.vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=self.Q_grads_tf,
                                       var_list=self.vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=self.pi_grads_tf,
                                        var_list=self.vars('main/pi'))

        self.Q_adam = MpiAdam(self.vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self.vars('main/pi'), scale_grad_by_procs=False)

        # self.Q_adam = tf.train.AdamOptimizer(learning_rate=self.Q_lr)
        # self.pi_adam = tf.train.AdamOptimizer(learning_rate=self.pi_lr)

        # self.Q_train_op = self.Q_adam.minimize(self.Q_loss_tf, var_list=self.vars('main/Q'))
        # self.pi_train_op = self.pi_adam.minimize(self.pi_loss_tf, var_list=self.vars('main/pi'))

        # polyak averaging
        self.main_vars = self.vars('main/Q') + self.vars('main/pi')
        self.target_vars = self.vars('target/Q') + self.vars('target/pi')
        self.stats_vars = self.global_vars('o_stats') + self.global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))
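        # update_target_net_op implements Polyak averaging,
        #   target <- polyak * target + (1 - polyak) * main,
        # while init_target_net_op hard-copies the main weights once at start-up.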

        self.param_noise_stddev = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='param_noise_stddev')

        # perturbe
        if self._parameter_noise is not None:
            self.perturbed_vars = self.vars('perturbed/Q') + self.vars(
                'perturbed/pi')
            self.perturbe_op = list(
                map(
                    lambda v: v[0].assign(v[1] + tf.random_normal(
                        tf.shape(v[1]),
                        mean=0.,
                        stddev=self.param_noise_stddev)),
                    zip(self.perturbed_vars, self.main_vars)))

            self.adaptive_param_noise_vars = self.vars(
                'adaptive_param_noise/Q') + self.vars(
                    'adaptive_param_noise/pi')
            self.adaptive_params_noise_perturbe_op = list(
                map(
                    lambda v: v[0].assign(v[1] + tf.random_normal(
                        tf.shape(v[1]),
                        mean=0.,
                        stddev=self.param_noise_stddev)),
                    zip(self.adaptive_param_noise_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self.global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

        self._sync_source_tf = [
            tf.placeholder(tf.float32, x.shape) for x in self.vars('')
        ]
        self._sync_op_tf = [
            target.assign(source)
            for target, source in zip(self.vars(''), self._sync_source_tf)
        ]

        self._global_sync_source_tf = [
            tf.placeholder(tf.float32, x.shape) for x in self.global_vars('')
        ]
        self._global_sync_op_tf = [
            target.assign(source) for target, source in zip(
                self.global_vars(''), self._global_sync_source_tf)
        ]

    def adapt_param_noise(self):
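        """Adapt the parameter-noise stddev: perturb a separate copy of the
        policy, measure the mean action-space distance to the unperturbed
        policy on a staged batch, and let the noise schedule grow or shrink
        its stddev accordingly (adaptive parameter-space noise exploration)."""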

        if not self._buffer.get_current_size() > 0: return

        self.sess.run(self.adaptive_params_noise_perturbe_op,
                      feed_dict={
                          self.param_noise_stddev:
                          self._parameter_noise.current_stddev
                      })

        self.stage_batch()
        distance = self.sess.run(self.adaptive_param_noise_distance)

        self._parameter_noise.adapt(distance)

    def vars(self, scope=''):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self._scope + '/' + scope)
        if scope == '':
            res += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=self._scope + '/' + 'o_stats/mean')
            res += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=self._scope + '/' + 'o_stats/std')
            res += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=self._scope + '/' + 'g_stats/mean')
            res += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=self._scope + '/' + 'g_stats/std')
        assert len(res) > 0
        return res

    def global_vars(self, scope=''):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self._scope + '/' + scope)
        return res

    def get_params(self, scope=''):
        return (self._scope, [(str(var), self.sess.run(var))
                              for var in self.vars(scope)])

    def get_global_params(self, scope=''):
        return (self._scope, [(str(var), self.sess.run(var))
                              for var in self.global_vars(scope)])

    def set_params(self, params, scope=''):
        params = [param[1] for param in params]
        self.sess.run(self._sync_op_tf,
                      feed_dict=dict([
                          (ph, param)
                          for ph, param in zip(self._sync_source_tf, params)
                      ]))

    def set_global_params(self, params, scope=''):
        params = [param[1] for param in params]
        self.sess.run(
            self._global_sync_op_tf,
            feed_dict=dict([
                (ph, param)
                for ph, param in zip(self._global_sync_source_tf, params)
            ]))
Example #14
0
class NAF(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 bc_loss,
                 q_filter,
                 num_demo,
                 demo_batch_size,
                 prm_loss_weight,
                 aux_loss_weight,
                 sample_transitions,
                 gamma,
                 reuse=False,
                 use_seperate_networks=False,
                 **kwargs):

        if self.clip_return is None:
            self.clip_return = np.inf

        if use_seperate_networks:
            self.create_naf_network = import_function(
                "her.naf_utils.naf_network_seperate:Network")
        else:
            self.create_naf_network = import_function(
                "her.naf_utils.naf_network_shared:Network")

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']
        self.counter = 0

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T - 1 if key != 'o' else self.T, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T, self.dimg)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

        global DEMO_BUFFER
        DEMO_BUFFER = ReplayBuffer(
            buffer_shapes, buffer_size, self.T, self.sample_transitions
        )  #initialize the demo buffer; in the same way as the primary data buffer

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def step(self, obs):
        actions = self.get_actions(obs['observation'], obs['achieved_goal'],
                                   obs['desired_goal'])
        return actions, None, None, None

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.1,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def init_demo_buffer(
            self,
            demoDataFile,
            update_stats=True):  #function that initializes the demo buffer

        demoData = np.load(
            demoDataFile)  #load the demonstration data from data file
        info_keys = [
            key.replace('info_', '') for key in self.input_dims.keys()
            if key.startswith('info_')
        ]
        info_values = [
            np.empty((self.T - 1, 1, self.input_dims['info_' + key]),
                     np.float32) for key in info_keys
        ]

        demo_data_obs = demoData['obs']
        demo_data_acs = demoData['acs']
        demo_data_info = demoData['info']

        for epsd in range(
                self.num_demo
        ):  # we initialize the whole demo buffer at the start of the training
            obs, acts, goals, achieved_goals = [], [], [], []
            i = 0
            for transition in range(self.T - 1):
                obs.append(
                    [demo_data_obs[epsd][transition].get('observation')])
                acts.append([demo_data_acs[epsd][transition]])
                goals.append(
                    [demo_data_obs[epsd][transition].get('desired_goal')])
                achieved_goals.append(
                    [demo_data_obs[epsd][transition].get('achieved_goal')])
                for idx, key in enumerate(info_keys):
                    info_values[idx][transition,
                                     i] = demo_data_info[epsd][transition][key]

            obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
            achieved_goals.append(
                [demo_data_obs[epsd][self.T - 1].get('achieved_goal')])

            episode = dict(o=obs, u=acts, g=goals, ag=achieved_goals)
            for key, value in zip(info_keys, info_values):
                episode['info_{}'.format(key)] = value

            episode = convert_episode_to_batch_major(episode)
            global DEMO_BUFFER
            DEMO_BUFFER.store_episode(
                episode
            )  # create the observation dict and append them into the demonstration buffer
            logger.debug("Demo buffer size currently ",
                         DEMO_BUFFER.get_current_size()
                         )  #print out the demonstration buffer size

            if update_stats:
                # add transitions to normalizer to normalize the demo data as well
                episode['o_2'] = episode['o'][:, 1:, :]
                episode['ag_2'] = episode['ag'][:, 1:, :]
                num_normalizing_transitions = transitions_in_episode_batch(
                    episode)
                transitions = self.sample_transitions(
                    episode, num_normalizing_transitions)

                o, g, ag = transitions['o'], transitions['g'], transitions[
                    'ag']
                transitions['o'], transitions['g'] = self._preprocess_og(
                    o, ag, g)
                # No need to preprocess the o_2 and g_2 since this is only used for stats

                self.o_stats.update(transitions['o'])
                self.g_stats.update(transitions['g'])

                self.o_stats.recompute_stats()
                self.g_stats.recompute_stats()
            episode.clear()

        logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size()
                    )  #print out the demonstration buffer size

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        Q_grad = self.sess.run([self.Q_grad_tf])
        return Q_grad

    def _update(self, Q_grad):
        self.adam.update(Q_grad, self.Q_lr)

    def sample_batch(self):
        if self.bc_loss:  #use demonstration buffer to sample as well if bc_loss flag is set TRUE
            transitions = self.buffer.sample(self.batch_size -
                                             self.demo_batch_size)
            global DEMO_BUFFER
            transitions_demo = DEMO_BUFFER.sample(
                self.demo_batch_size)  #sample from the demo buffer
            for k, values in transitions_demo.items():
                rolloutV = transitions[k].tolist()
                for v in values:
                    rolloutV.append(v.tolist())
                transitions[k] = np.array(rolloutV)
        else:
            transitions = self.buffer.sample(
                self.batch_size)  #otherwise only sample from primary buffer

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        Q_grad = self._grads()
        self._update(Q_grad[0])

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))
        self.sess = tf_util.get_session()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        #choose only the demo buffer samples
        mask = np.concatenate(
            (np.zeros(self.batch_size - self.demo_batch_size),
             np.ones(self.demo_batch_size)),
            axis=0)
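        # `mask` marks the demo-buffer rows of the staged batch (the last
        # demo_batch_size rows appended in sample_batch when bc_loss is set);
        # in the network as shown it is computed but not consumed further.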

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_naf_network(batch_tf,
                                                net_type='main',
                                                **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_naf_network(target_batch_tf,
                                                  net_type='target',
                                                  **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_value = self.target.value
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_value,
                                     *clip_range)
        self.Q_loss_tf = tf.reduce_mean(tf.square(target_tf - self.main.Q))
        tf.summary.histogram("Q_loss", self.Q_loss_tf)

        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main'))

        assert len(self._vars('main')) == len(Q_grads_tf)

        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main'))

        # optimizers
        self.adam = MpiAdam(self._vars('main'), scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main')
        self.target_vars = self._vars('target')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

        visualize = True
        if visualize:
            writer = tf.summary.FileWriter("output", self.sess.graph)
            writer.close()
            saver = tf.train.Saver()
            saver.save(self.sess, "./models/model.ckpt")

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all(subname not in k for subname in excluded_subnames)
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)

    def save(self, save_path):
        tf_util.save_variables(save_path)
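The `_create_network` method above builds the target network as a Polyak-averaged copy of the main network by constructing explicit assign ops: one hard copy for initialization and one soft update run after each training step. The snippet below is a minimal TF1-style sketch of that pattern under assumed names (`main_vars`, `target_vars`, `polyak`); it is not the class's own API.

import tensorflow as tf

def make_target_update_ops(main_vars, target_vars, polyak=0.95):
    # hard copy used once after variable initialization: target <- main
    init_op = tf.group(*[t.assign(m) for t, m in zip(target_vars, main_vars)])
    # soft update run after every training step:
    #   target <- polyak * target + (1 - polyak) * main
    update_op = tf.group(*[
        t.assign(polyak * t + (1.0 - polyak) * m)
        for t, m in zip(target_vars, main_vars)
    ])
    return init_op, update_op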
Example #15
0
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 normalize_obs,
                 sample_transitions,
                 gamma,
                 buffers=None,
                 reuse=False,
                 tasks_ag_id=None,
                 tasks_g_id=None,
                 task_replay='',
                 t_id=None,
                 eps_task=None,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
            buffers (list): buffers used to store new transitions (usually one per task + 1)
            tasks_ag_id (list): indices of each task's achieved goals in the achieved-goal vector
            tasks_g_id (list): indices of each task's goals in the goal vector
            task_replay (str): defines the task replay strategy (see train.py for info)
            t_id (int): index of the task corresponding to this policy when using a task-experts structure
            eps_task (float): epsilon parameter for the epsilon-greedy strategy (task choice)
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)
        self.normalize_obs = normalize_obs

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimag = self.input_dims['ag']
        self.dimu = self.input_dims['u']
        if self.structure == 'curious' or self.structure == 'task_experts':
            self.dimtd = self.input_dims['task_descr']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, 1)
        self.stage_shapes = stage_shapes

        if t_id is not None:
            self.scope += str(t_id)
        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # addition for multi-task structures
        if self.structure == 'curious' or self.structure == 'task_experts':
            self.tasks_g_id = tasks_g_id
            self.tasks_ag_id = tasks_ag_id
            self.nb_tasks = len(tasks_g_id)

        if buffers is not None:
            self.buffer = buffers
            if type(self.buffer) is list:
                if len(self.buffer) > 5:
                    # distractor buffers are equal
                    for i in range(6, len(self.buffer)):
                        self.buffer[i] = self.buffer[5]
        self.first = True

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimag)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self,
                    o,
                    ag,
                    g,
                    task_descr=None,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }
        # addition for multi-task structures
        if self.structure == 'curious' or self.structure == 'task_experts':
            feed[policy.td_tf] = task_descr.reshape(-1, self.dimtd)

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, cp, n_ep, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """
        # decompose episode_batch in episodes
        batch_size = episode_batch['ag'].shape[0]
        # addition for curious goals: count how many achieved goals moved in each of the n modules
        self.cp = cp
        self.n_episodes = n_ep
        # addition for multi-task structures
        if self.structure == 'curious' or self.structure == 'task_experts':
            new_count_local = np.zeros([self.nb_tasks])
            new_count_total = np.zeros([self.nb_tasks])
            # add an episode to a task buffer only if the corresponding outcome has changed compared to the initial outcome
            for b in range(batch_size):
                active_tasks = []
                for j in range(self.nb_tasks):
                    if any(episode_batch['change']
                           [b, -1,
                            self.tasks_ag_id[j][:len(self.tasks_g_id[j])]]):
                        new_count_local[j] += 1
                        if self.nb_tasks < 5 or j < 5:
                            active_tasks.append(j)
                MPI.COMM_WORLD.Allreduce(new_count_local,
                                         new_count_total,
                                         op=MPI.SUM)
                ep = dict()
                for key in episode_batch.keys():
                    ep[key] = episode_batch[key][b].reshape([
                        1, episode_batch[key].shape[1],
                        episode_batch[key].shape[2]
                    ])

                if 'buffer' in self.task_replay or self.task_replay == 'hand_designed':
                    # store in the generic buffer (index 0) when no task outcome
                    # changed, otherwise in the buffer of every active task
                    if len(active_tasks) == 0:
                        ind_buffer = [0]
                    else:
                        ind_buffer = [task + 1 for task in active_tasks]
                    for ind in ind_buffer:
                        self.buffer[ind].store_episode(ep)
                else:
                    self.buffer.store_episode(ep)

        elif self.structure == 'flat' or self.structure == 'task_experts':
            for b in range(batch_size):
                ep = dict()
                for key in episode_batch.keys():
                    ep[key] = episode_batch[key][b].reshape([
                        1, episode_batch[key].shape[1],
                        episode_batch[key].shape[2]
                    ])
                self.buffer.store_episode(ep)

        # update statistics for goal and observation normalizations
        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            if self.structure == 'curious' or self.structure == 'task_experts':
                transitions = self.sample_transitions(
                    episode_batch,
                    num_normalizing_transitions,
                    task_to_replay=None)
            else:
                transitions = self.sample_transitions(
                    episode_batch, num_normalizing_transitions)
            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])
            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return sum(
            [self.buffer[i].get_current_size() for i in range(self.nb_tasks)])

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf, self.main.Q_pi_tf, self.Q_grad_tf, self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        # addition for multi-task structures
        if self.structure == 'curious' or self.structure == 'task_experts':
            if self.structure == 'curious':
                if 'buffer' in self.task_replay or self.task_replay == 'hand_designed':
                    buffers_sizes = np.array([
                        self.buffer[i].current_size * self.T
                        for i in range(self.nb_tasks + 1)
                    ])
                    self.proportions = np.zeros([self.nb_tasks + 1])
                    if buffers_sizes[1:].sum() < self.T:
                        ind_valid_buffers = np.array([0])
                        n_valid = 1
                        self.proportions = buffers_sizes / buffers_sizes.sum(
                        ) * self.batch_size
                    else:
                        ind_valid_buffers = np.argwhere(buffers_sizes[1:] > 0)
                        ind_valid_buffers = ind_valid_buffers.reshape(
                            [ind_valid_buffers.size])
                        n_valid = len(ind_valid_buffers)

                        # draw transition from random buffers (random tasks)
                        if self.task_replay == 'replay_task_random_buffer':
                            proba = 1 / ind_valid_buffers.size * np.ones(
                                [n_valid])
                        elif self.task_replay == 'replay_task_cp_buffer':
                            CP = self.cp[ind_valid_buffers]
                            if CP.sum() == 0:
                                proba = (1 / n_valid) * np.ones([n_valid])
                            else:
                                proba = self.eps_task * (1 / n_valid) * np.ones([n_valid]) + \
                                         (1 - self.eps_task) * CP / CP.sum()
                            proba[-1] = 1 - proba[:-1].sum()
                        self.proportions[ind_valid_buffers +
                                         1] = proba * self.batch_size

                    self.proportions = self.proportions.astype(int)
                    remain = self.batch_size - self.proportions.sum()
                    for i in range(remain):
                        self.proportions[ind_valid_buffers[i % n_valid] +
                                         1] += 1

                elif self.task_replay == 'replay_cp_task_transition':
                    CP = self.cp.copy()
                    if CP.sum() == 0:
                        proba = (1 / self.nb_tasks) * np.ones([self.nb_tasks])
                    else:
                        proba = self.eps_task * (1 / self.nb_tasks) * np.ones([self.nb_tasks]) + \
                                (1 - self.eps_task) * CP / CP.sum()
                    proba[-1] = 1 - proba[:-1].sum()
                    transitions = self.buffer.sample(self.batch_size,
                                                     task_to_replay=None,
                                                     cp_proba=proba)

                else:
                    transitions = self.buffer.sample(self.batch_size,
                                                     task_to_replay=None,
                                                     cp_proba=None)

            elif self.structure == 'task_experts':
                if self.task_replay == 'replay_current_task_buffer':
                    buffers_sizes = np.array([
                        self.buffer[i].current_size * self.T
                        for i in range(self.nb_tasks + 1)
                    ])
                    ind_valid_buffers = np.argwhere(buffers_sizes > 0)
                    ind_valid_buffers = ind_valid_buffers.reshape(
                        [ind_valid_buffers.size])
                    n_valid = len(ind_valid_buffers)
                    self.proportions = np.zeros([self.nb_tasks + 1])
                    if buffers_sizes[self.t_id + 1] > 0:
                        self.proportions[self.t_id + 1] = 1
                    else:
                        self.proportions[ind_valid_buffers] = 1 / len(
                            ind_valid_buffers)
                    self.proportions *= self.batch_size
                    self.proportions = self.proportions.astype(int)
                    remain = self.batch_size - self.proportions.sum()
                    for i in range(remain):
                        self.proportions[ind_valid_buffers[i % n_valid]] += 1
                else:
                    transitions = self.buffer.sample(self.batch_size,
                                                     task_to_replay=None,
                                                     cp_proba=None)

            if 'buffer' in self.task_replay or self.task_replay == 'hand_designed':
                assert self.proportions.sum() == self.batch_size

                # sample transitions from different buffers
                trans = []
                for i in range(self.nb_tasks + 1):
                    if self.proportions[i] > 0:
                        if self.structure == 'curious':
                            if i > 0:
                                task_to_replay = i - 1
                            else:
                                task_to_replay = None
                        else:
                            task_to_replay = self.t_id
                        trans.append(self.buffer[i].sample(
                            self.proportions[i],
                            task_to_replay=task_to_replay))
                # concatenate transitions from different buffers and shuffle
                shuffle_inds = np.arange(self.batch_size)
                np.random.shuffle(shuffle_inds)
                transitions = dict()
                for key in trans[0].keys():
                    tmp = np.array([]).reshape([0, trans[0][key].shape[1]])
                    for ts in trans:
                        tmp = np.concatenate([tmp, ts[key]])
                    transitions[key] = tmp[shuffle_inds, :]

        elif self.structure == 'flat':
            transitions = self.buffer.sample(self.batch_size)

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        # #test addition !!
        # transitions['task_descr'] = np.zeros([256, 8])

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]

        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        for i in range(self.nb_tasks):
            self.buffer[i].clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        if MPI.COMM_WORLD.Get_rank() == 0:
            logger.info("g a DDPG agent with action space %d x %s..." %
                        (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def save_weights(self, path):
        to_save = []
        scopes_var = ['main/Q', 'main/pi', 'target/Q', 'target/pi']
        scopes_global_var = ['o_stats', 'g_stats']
        for s in scopes_var:
            tmp = []
            for v in self._vars(s):
                tmp.append(v.eval())
            to_save.append(tmp)
        for s in scopes_global_var:
            tmp = []
            for v in self._global_vars(s):
                tmp.append(v.eval())
            to_save.append(tmp)

        with open(path + '_weights.pkl', 'wb') as f:
            pickle.dump(to_save, f)

    def load_weights(self, path):
        with open(path + '_weights.pkl', 'rb') as f:
            weights = pickle.load(f)
        scopes_var = ['main/Q', 'main/pi', 'target/Q', 'target/pi']
        scopes_global_var = ['o_stats', 'g_stats']
        for i_s, s in enumerate(scopes_var):
            for i_v, v in enumerate(self._vars(s)):
                v.load(weights[i_s][i_v])
        for i_s, s in enumerate(scopes_global_var):
            for i_v, v in enumerate(self._global_vars(s)):
                v.load(weights[i_s + len(scopes_var)][i_v])

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all(subname not in k for subname in excluded_subnames)
        }
        # state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
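In `sample_batch` above, the 'curious' structure fills each mini-batch from several per-task buffers: a uniform epsilon term is mixed with learning-progress (CP) weights, and the resulting probabilities are converted into integer per-buffer counts whose rounding remainder is spread over the valid buffers. The NumPy sketch below illustrates that proportion computation under assumed names (`cp`, `eps_task`, `batch_size`); it is a simplified stand-in, not the exact class internals.

import numpy as np

def task_proportions(cp, batch_size, eps_task=0.1):
    cp = np.asarray(cp, dtype=float)
    n = cp.size
    if cp.sum() == 0:
        proba = np.ones(n) / n                      # no progress signal yet: uniform
    else:
        proba = eps_task / n + (1.0 - eps_task) * cp / cp.sum()
    proba[-1] = 1.0 - proba[:-1].sum()              # force the probabilities to sum to 1
    counts = (proba * batch_size).astype(int)       # integer number of transitions per buffer
    for i in range(batch_size - counts.sum()):      # distribute the rounding remainder
        counts[i % n] += 1
    return counts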
class DDPG_HER_HRL_POLICY(HRL_Policy):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 reuse=False,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        self.ep_ctr = 0
        self.hist_bins = 50
        self.draw_hist_freq = 3
        self._reset_hists()
        self.shared_pi_err_coeff = kwargs['shared_pi_err_coeff']

        HRL_Policy.__init__(self, input_dims, T, rollout_batch_size, **kwargs)

        self.hidden = hidden
        self.layers = layers
        self.max_u = max_u
        self.network_class = network_class
        self.sample_transitions = sample_transitions
        self.scope = scope
        self.subtract_goals = subtract_goals
        self.relative_goals = relative_goals
        self.clip_obs = clip_obs
        self.Q_lr = Q_lr
        self.pi_lr = pi_lr
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.clip_pos_returns = clip_pos_returns
        self.gamma = gamma
        self.polyak = polyak
        self.clip_return = clip_return
        self.norm_eps = norm_eps
        self.norm_clip = norm_clip
        self.action_l2 = action_l2
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)
        self.stage_shapes['gamma'] = (None, )
        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key:
            (self.T if key != 'o' else self.T + 1, *self.input_shapes[key])
            for key, val in self.input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)
        buffer_shapes['p'] = (buffer_shapes['g'][0], 1)
        buffer_shapes['steps'] = buffer_shapes['p']
        buffer_size = self.buffer_size  # not rounded here: (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

        self.preproc_lr = (self.Q_lr + self.pi_lr) / 2

    def _reset_hists(self):
        self.hists = {"attn": None, "prob_in": None, "rnd": None}

    def draw_hists(self, img_dir):
        for hist_name, hist in self.hists.items():
            if hist is None:
                continue
            step_size = 1.0 / self.hist_bins
            xs = np.arange(0, 1, step_size)
            hist /= (self.ep_ctr * self.T)
            fig, ax = plt.subplots()
            ax.bar(xs, hist, step_size)
            plt.savefig(img_dir + "/{}_hist_l_{}_ep_{}.png".format(
                hist_name, self.h_level, self.ep_ctr))
        self._reset_hists()
        if self.child_policy is not None:
            self.child_policy.draw_hists(img_dir)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False,
                    exploit=True,
                    **kwargs):
        # noise_eps = noise_eps if not exploit else 0.
        # random_eps = random_eps if not exploit else 0.

        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf, policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        q = ret[1]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        noisy_u = u + noise
        u = np.clip(noisy_u, -self.max_u, self.max_u)
        random_u = np.random.binomial(1, random_eps, u.shape[0]).reshape(
            -1, 1) * (self._random_action(u.shape[0]) - noisy_u)  # eps-greedy
        u += random_u
        u = u[0].copy()
        self.update_hists(feed, policy)
        return u, q

    def update_hists(self, feed, policy):
        vals = []
        hist_names_to_consider = []
        for hist_name, hist in self.hists.items():
            if hist_name in self.main.__dict__:
                hist_names_to_consider.append(hist_name)
                vals.append(getattr(policy, hist_name))

        ret = self.sess.run(vals, feed_dict=feed)
        for val_idx, hist_name in enumerate(hist_names_to_consider):
            this_vals = ret[val_idx]
            this_hists = np.histogram(this_vals, self.hist_bins, range=(0, 1))
            if self.hists[hist_name] is None:
                self.hists[hist_name] = this_hists[0] / this_vals.shape[1]
            else:
                self.hists[hist_name] += this_hists[0] / this_vals.shape[1]

    def scale_and_offset_action(self, u):
        scaled_u = u.copy()
        scaled_u *= self.subgoal_scale
        scaled_u += self.subgoal_offset
        return scaled_u

    def inverse_scale_and_offset_action(self, scaled_u):
        u = scaled_u.copy()
        u -= self.subgoal_offset
        u /= self.subgoal_scale
        return u

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """
        # print("Storing Episode h-level = {}".format(self.h_level))
        self.buffer.store_episode(episode_batch)
        if update_stats:
            # add transitions to normalizer

            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            # num_normalizing_transitions = episode_batch['u'].shape[1]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.o_stats.recompute_stats()

            self.g_stats.update(transitions['g'])
            self.g_stats.recompute_stats()

        self.ep_ctr += 1
        # print("Done storing Episode")

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()
        self.shared_preproc_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, preproc_loss, Q_grad, pi_grad, preproc_grad = self.sess.run(
            [
                self.Q_loss_tf, self.main.Q_pi_tf, self.shared_preproc_loss_tf,
                self.Q_grad_tf, self.pi_grad_tf, self.shared_preproc_grad_tf
            ])
        return critic_loss, actor_loss, preproc_loss, Q_grad, pi_grad, preproc_grad

    def _update(self, Q_grad, pi_grad, preproc_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)
        self.shared_preproc_adam.update(preproc_grad, self.preproc_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)
        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, preproc_loss, Q_grad, pi_grad, preproc_grad = self._grads(
        )
        self._update(Q_grad, pi_grad, preproc_grad)
        return critic_loss, actor_loss, preproc_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        # assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG_HRL agent with action space %d x %s..." %
                    (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        # target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        target_tf = tf.clip_by_value(
            batch_tf['r'] +
            tf.transpose(self.gamma * batch_tf['gamma']) * target_Q_pi_tf,
            *clip_range)
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
        self.shared_q_err_coeff = 1.0 - self.shared_pi_err_coeff
        self.shared_preproc_loss_tf = (
            self.shared_q_err_coeff * self.Q_loss_tf +
            self.shared_pi_err_coeff * self.pi_loss_tf)
        if "shared_preproc_err" in self.main.__dict__:
            self.shared_preproc_loss_tf += self.main.shared_preproc_err
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        shared_preproc_grads_tf = tf.gradients(
            self.shared_preproc_loss_tf, self._vars('main/shared_preproc'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        assert len(
            self._vars('main/shared_preproc')) == len(shared_preproc_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.shared_preproc_grads_vars_tf = zip(
            shared_preproc_grads_tf, self._vars('main/shared_preproc'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))
        self.shared_preproc_grad_tf = flatten_grads(
            grads=shared_preproc_grads_tf,
            var_list=self._vars('main/shared_preproc'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)
        self.shared_preproc_adam = MpiAdam(self._vars('main/shared_preproc'),
                                           scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars(
            'main/pi') + self._vars('main/shared_preproc')
        self.target_vars = self._vars('target/Q') + self._vars(
            'target/pi') + self._vars('target/shared_preproc')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix='policy'):
        logs = []
        logs += [('buffer_size', int(self.buffer.current_size))]
        logs = log_formater(logs, prefix + "_{}".format(self.h_level))
        if self.child_policy is not None:
            child_logs = self.child_policy.logs(prefix=prefix)
            logs += child_logs

        return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all(subname not in k for subname in excluded_subnames)
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
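Unlike the flat DDPG above, the HRL policy stages an extra 'gamma' entry per transition (`self.stage_shapes['gamma'] = (None, )`) and multiplies it into the bootstrap term, so every sampled transition can carry its own effective discount, e.g. reflecting how many low-level steps a subgoal lasted. A minimal sketch of that clipped TD target follows; `r`, `q_next`, `step_gamma` and the clipping constants are illustrative placeholders, not the attributes used above.

import numpy as np
import tensorflow as tf

def clipped_td_target(r, q_next, step_gamma, gamma=0.98,
                      clip_return=50.0, clip_pos_returns=True):
    # r, q_next and step_gamma are assumed to broadcast to shape (batch, 1)
    clip_range = (-clip_return, 0.0 if clip_pos_returns else np.inf)
    # each transition scales the bootstrapped value by its own staged gamma
    return tf.clip_by_value(r + gamma * step_gamma * q_next, *clip_range)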
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 action_scale,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 bc_loss,
                 q_filter,
                 num_demo,
                 demo_batch_size,
                 prm_loss_weight,
                 aux_loss_weight,
                 sample_transitions,
                 gamma,
                 temperature,
                 prioritization,
                 env_name,
                 alpha,
                 beta0,
                 beta_iters,
                 total_timesteps,
                 rank_method,
                 reuse=False,
                 **kwargs):
        """
            Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
            Added functionality to use demonstrations during training to overcome the exploration problem.
        Args:
            :param input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            :param buffer_size (int): number of transitions that are stored in the replay buffer
            :param hidden (int): number of units in the hidden layers
            :param layers (int): number of hidden layers
            :param network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            :param polyak (float): coefficient for Polyak-averaging of the target network
            :param batch_size (int): batch size for training
            :param Q_lr (float): learning rate for the Q (critic) network
            :param pi_lr (float): learning rate for the pi (actor) network
            :param norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            :param norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            :param action_scale (float): maximum action magnitude, i.e. actions are in [-action_scale, action_scale]
            :param action_l2 (float): coefficient for L2 penalty on the actions
            :param clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            :param scope (str): the scope used for the TensorFlow graph
            :param T (int): the time horizon for rollouts
            :param rollout_batch_size (int): number of parallel rollouts per DDPG agent
            :param subtract_goals (function): function that subtracts goals from each other
            :param relative_goals (boolean): whether or not relative goals should be fed into the network
            :param clip_pos_returns (boolean): whether or not positive returns should be clipped
            :param clip_return (float): clip returns to be in [-clip_return, clip_return]
            :param sample_transitions (function) function that samples from the replay buffer
            :param gamma (float): gamma used for Q learning updates
            :param reuse (boolean): whether or not the networks should be reused
            :param bc_loss: whether or not the behavior cloning loss should be used as an auxiliary loss
            :param q_filter: whether or not a filter on the Q-value update should be used when training with demonstrations
            :param num_demo: number of episodes to be used in the demonstration buffer
            :param demo_batch_size: number of samples to be used from the demonstrations buffer, per mpi thread
            :param prm_loss_weight: weight corresponding to the primary loss
            :param aux_loss_weight: weight corresponding to the auxiliary loss, also called the cloning loss
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(
            self.network_class)  # points to actor_critic.py

        self.input_dims = input_dims

        input_shapes = dims_to_shapes(input_dims)
        self.dimo = input_dims['o']
        self.dimg = input_dims['g']
        self.dimu = input_dims['u']

        self.sample_count = 1
        self.cycle_count = 1

        self.critic_loss_episode = []
        self.actor_loss_episode = []
        self.critic_loss_avg = []
        self.actor_loss_avg = []

        # Energy based parameters
        self.prioritization = prioritization
        self.env_name = env_name
        self.temperature = temperature
        self.rank_method = rank_method

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes
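        # After this loop stage_shapes holds the (batched) policy inputs plus the next-step
        # observation/goal and the reward, e.g. (illustrative, for o=10, g=3, u=4):
        # OrderedDict([('g', (None, 3)), ('o', (None, 10)), ('u', (None, 4)),
        #              ('o_2', (None, 10)), ('g_2', (None, 3)), ('r', (None,))])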

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)  # Creates DDPG agent

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T - 1 if key != 'o' else self.T, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T, self.dimg)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size

        # print("begin init")
        if self.prioritization == 'energy':
            self.buffer = ReplayBufferEnergy(buffer_shapes, buffer_size,
                                             self.T, self.sample_transitions,
                                             self.prioritization,
                                             self.env_name)
        # elif self.prioritization == 'tderror':
        #     self.buffer = PrioritizedReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions, alpha)
        #     if beta_iters is None:
        #         beta_iters = total_timesteps
        #     self.beta_schedule = LinearSchedule(beta_iters, initial_p=beta0, final_p=1.0)
        else:
            self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                       self.sample_transitions)

        # print("finish init")

    def _random_action(self, n):
        return np.random.uniform(low=-self.action_scale,
                                 high=self.action_scale,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:  # relative_goals is typically False in this configuration
            print("self.relative_goals: ", self.relative_goals)
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)

        # Clip (limit) the values in an array.
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)

        return o, g

    # Not used
    def step(self, obs):
        actions = self.get_actions(obs['observation'], obs['achieved_goal'],
                                   obs['desired_goal'])
        return actions, None, None, None

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):

        o, g = self._preprocess_og(o, ag, g)

        # Use the target network if requested; otherwise use the main network
        policy = self.target if use_target_net else self.main

        # values to compute
        policy_weights = [policy.actor_tf]

        if compute_Q:
            policy_weights += [policy.critic_with_actor_tf]

        # feeds
        agent_feed = {
            policy.obs:
            o.reshape(-1, self.dimo),
            policy.goals:
            g.reshape(-1, self.dimg),
            policy.actions:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        # Evaluate the requested tensors (action, and optionally Q) for the given observations and goals
        ret = self.sess.run(policy_weights, feed_dict=agent_feed)

        # print(ret)

        # action postprocessing
        action = ret[0]
        noise = noise_eps * self.action_scale * np.random.randn(
            *action.shape)  # gaussian noise
        action += noise
        action = np.clip(action, -self.action_scale, self.action_scale)
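        # Epsilon-greedy exploration: np.random.binomial(1, random_eps, ...) draws 1 with
        # probability random_eps per sample; where it is 1, the next line replaces the action
        # with a uniformly random one, otherwise the action is left unchanged.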
        action += np.random.binomial(1, random_eps, action.shape[0]).reshape(
            -1, 1) * (self._random_action(action.shape[0]) - action
                      )  # eps-greedy
        if action.shape[0] == 1:
            action = action[0]
        action = action.copy()
        ret[0] = action

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    # Not used
    # def init_demo_buffer(self, demoDataFile, update_stats=True):  # function that initializes the demo buffer
    #
    #     demoData = np.load(demoDataFile)  # load the demonstration data from data file
    #     info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
    #     info_values = [np.empty((self.T - 1, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]
    #
    #     demo_data_obs = demoData['obs']
    #     demo_data_acs = demoData['acs']
    #     demo_data_info = demoData['info']
    #
    #     for epsd in range(self.num_demo):  # we initialize the whole demo buffer at the start of the training
    #         obs, acts, goals, achieved_goals = [], [], [], []
    #         i = 0
    #         for transition in range(self.T - 1):
    #             obs.append([demo_data_obs[epsd][transition].get('observation')])
    #             acts.append([demo_data_acs[epsd][transition]])
    #             goals.append([demo_data_obs[epsd][transition].get('desired_goal')])
    #             achieved_goals.append([demo_data_obs[epsd][transition].get('achieved_goal')])
    #             for idx, key in enumerate(info_keys):
    #                 info_values[idx][transition, i] = demo_data_info[epsd][transition][key]
    #
    #         obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
    #         achieved_goals.append([demo_data_obs[epsd][self.T - 1].get('achieved_goal')])
    #
    #         episode = dict(observations=obs,
    #                        u=acts,
    #                        g=goals,
    #                        ag=achieved_goals)
    #         for key, value in zip(info_keys, info_values):
    #             episode['info_{}'.format(key)] = value
    #
    #         episode = convert_episode_to_batch_major(episode)
    #         global DEMO_BUFFER
    #         DEMO_BUFFER.ddpg_store_episode(
    #             episode)  # create the observation dict and append them into the demonstration buffer
    #         logger.debug("Demo buffer size currently ",
    #                      DEMO_BUFFER.get_current_size())  # print out the demonstration buffer size
    #
    #         if update_stats:
    #             # add transitions to normalizer to normalize the demo data as well
    #             episode['o_2'] = episode['o'][:, 1:, :]
    #             episode['ag_2'] = episode['ag'][:, 1:, :]
    #             num_normalizing_transitions = transitions_in_episode_batch(episode)
    #             transitions = self.sample_transitions(episode, num_normalizing_transitions)
    #
    #             o, g, ag = transitions['o'], transitions['g'], transitions['ag']
    #             transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
    #             # No need to preprocess the o_2 and g_2 since this is only used for stats
    #
    #             self.o_stats.update(transitions['o'])
    #             self.g_stats.update(transitions['g'])
    #
    #             self.o_stats.recompute_stats()
    #             self.g_stats.recompute_stats()
    #         episode.clear()
    #
    #     logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size())  # print out the demonstration buffer size

    def ddpg_store_episode(self,
                           episode_batch,
                           dump_buffer,
                           w_potential,
                           w_linear,
                           w_rotational,
                           rank_method,
                           clip_energy,
                           update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        # if self.prioritization == 'tderror':
        #     self.buffer.store_episode(episode_batch, dump_buffer)

        # print("DDPG BEGIN STORE episode")
        if self.prioritization == 'energy':
            self.buffer.store_episode(episode_batch, w_potential, w_linear,
                                      w_rotational, rank_method, clip_energy)
        else:
            self.buffer.store_episode(episode_batch)

        # print("DDPG END STORE episode")

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            # print("START ddpg sample transition")
            # n_cycles calls HER sampler
            if self.prioritization == 'energy':
                if not self.buffer.current_size == 0 and not len(
                        episode_batch['ag']) == 0:
                    transitions = self.sample_transitions(
                        episode_batch, num_normalizing_transitions, 'none',
                        1.0, self.sample_count, self.cycle_count, True)
            # elif self.prioritization == 'tderror':
            #     transitions, weights, episode_idxs = \
            #         self.sample_transitions(self.buffer, episode_batch, num_normalizing_transitions, beta=0)
            else:
                transitions = self.sample_transitions(
                    episode_batch, num_normalizing_transitions)
            # print("END ddpg sample transition")
            # print("DDPG END STORE episode 2")
            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.critic_optimiser.sync()
        self.actor_optimiser.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, critic_grad, actor_grad, td_error = self.sess.run(
            [
                self.critic_loss_tf,  # MSE of target_tf - main.critic_tf
                self.main.critic_with_actor_tf,  # Q(s, pi(s)); its negative mean is the actor loss
                self.critic_grads,
                self.actor_grads,
                self.td_error_tf
            ])
        return critic_loss, actor_loss, critic_grad, actor_grad, td_error

    def _update(self, critic_grads, actor_grads):
        self.critic_optimiser.update(critic_grads, self.Q_lr)
        self.actor_optimiser.update(actor_grads, self.pi_lr)

    def sample_batch(self, t):
        # print("Begin Sample batch")
        if self.prioritization == 'energy':
            transitions = self.buffer.sample(self.batch_size,
                                             self.rank_method,
                                             temperature=self.temperature)
            weights = np.ones_like(transitions['r']).copy()
            # print("reach?")
        # elif self.prioritization == 'tderror':
        #     transitions, weights, idxs = self.buffer.sample(self.batch_size, beta=self.beta_schedule.value(t))
        else:
            transitions = self.buffer.sample(self.batch_size)
            weights = np.ones_like(transitions['r']).copy()

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions['w'] = weights.flatten().copy()  # note: ordered dict
        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        # if self.prioritization == 'tderror':
        #     return (transitions_batch, idxs)
        # else:
        # print("End sample batch")
        return transitions_batch

    def stage_batch(self, t, batch=None):
        if batch is None:
            # if self.prioritization == 'tderror':
            #     batch, idxs = self.sample_batch(t)
            # else:
            batch = self.sample_batch(t)
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

        # if self.prioritization == 'tderror':
        #     return idxs

    def ddpg_train(self, t, dump_buffer, stage=True):
        if stage:
            # if self.prioritization == 'tderror':
            #     idxs = self.stage_batch(t)
            # else:
            self.stage_batch(t)

        self.critic_loss, self.actor_loss, Q_grad, pi_grad, td_error = self._grads()

        # if self.prioritization == 'tderror':
        #     new_priorities = np.abs(td_error) + self.eps  # td_error
        #     if dump_buffer:
        #         T = self.buffer.buffers['u'].shape[1]
        #         episode_idxs = idxs // T
        #         t_samples = idxs % T
        #         batch_size = td_error.shape[0]
        #         with self.buffer.lock:
        #             for i in range(batch_size):
        #                 self.buffer.buffers['td'][episode_idxs[i]][t_samples[i]] = td_error[i]
        #
        #     self.buffer.update_priorities(idxs, new_priorities)

        # Update gradients for actor and critic networks
        self._update(Q_grad, pi_grad)

        # Bookkeeping for logging
        self.visual_actor_loss = 1 - self.actor_loss
        self.critic_loss_episode.append(self.critic_loss)
        self.actor_loss_episode.append(self.visual_actor_loss)

        # print("Critic loss: ", self.critic_loss, " Actor loss: ", self.actor_loss)
        return self.critic_loss, np.mean(self.actor_loss)

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def ddpg_update_target_net(self):
        # print("ddpg_cycle", self.cycle_count)
        self.cycle_count += 1
        self.critic_loss_avg = np.mean(self.critic_loss_episode)
        self.actor_loss_avg = np.mean(self.actor_loss_episode)

        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.action_scale))
        self.sess = tf_util.get_session()

        # running averages
        with tf.variable_scope('o_stats') as variable_scope:
            if reuse:
                variable_scope.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as variable_scope:
            if reuse:
                variable_scope.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # choose only the demo buffer samples
        mask = np.concatenate(
            (np.zeros(self.batch_size - self.demo_batch_size),
             np.ones(self.demo_batch_size)),
            axis=0)
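        # Note: `mask` selects the trailing demo_batch_size samples of each batch; in
        # demonstration-augmented DDPG it is typically used to restrict the behavior-cloning
        # (and optional Q-filter) loss to demo samples. It is not used further in the code shown here.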

        # networks
        with tf.variable_scope('main') as variable_scope:
            if reuse:
                variable_scope.reuse_variables()

            # Create actor critic network
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            variable_scope.reuse_variables()

        with tf.variable_scope('target') as variable_scope:
            if reuse:
                variable_scope.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            variable_scope.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_critic_actor_tf = self.target.critic_with_actor_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)

        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_critic_actor_tf, *clip_range)

        # MSE of target_tf - critic_tf. This is the TD Learning step
        self.td_error_tf = tf.stop_gradient(target_tf) - self.main.critic_tf
        self.critic_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.critic_tf))

        # Actor loss: negative mean Q(s, pi(s)), plus an L2 penalty on the action magnitude
        self.actor_loss_tf = -tf.reduce_mean(self.main.critic_with_actor_tf)
        self.actor_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.actor_tf / self.action_scale))

        # Constructs symbolic derivatives of sum of critic_loss_tf vs _vars('main/Q')
        critic_grads_tf = tf.gradients(self.critic_loss_tf,
                                       self._vars('main/Q'))
        actor_grads_tf = tf.gradients(self.actor_loss_tf,
                                      self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(critic_grads_tf)
        assert len(self._vars('main/pi')) == len(actor_grads_tf)
        self.critic_grads_vars_tf = zip(critic_grads_tf, self._vars('main/Q'))
        self.actor_grads_vars_tf = zip(actor_grads_tf, self._vars('main/pi'))

        # Flattens variables and their gradients.
        self.critic_grads = flatten_grads(grads=critic_grads_tf,
                                          var_list=self._vars('main/Q'))
        self.actor_grads = flatten_grads(grads=actor_grads_tf,
                                         var_list=self._vars('main/pi'))

        # optimizers
        self.critic_optimiser = MpiAdam(self._vars('main/Q'),
                                        scale_grad_by_procs=False)
        self.actor_optimiser = MpiAdam(self._vars('main/pi'),
                                       scale_grad_by_procs=False)

        # polyak averaging used to update target network
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')

        # Initialize the target network by copying the main network's weights.
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))

        # Polyak (exponential moving) averaging: the target network slowly tracks the main
        # network, target <- polyak * target + (1 - polyak) * main.
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) *
                                      v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('actor_critic/critic_loss', self.critic_loss_avg)]
        logs += [('actor_critic/actor_loss', self.actor_loss_avg)]

        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        # logs += [('critic_loss', np.mean(self.sess.run([self.critic_loss])))]
        # logs += [('actor_loss', np.mean(self.sess.run([self.actor_loss])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)

    def save(self, save_path):
        tf_util.save_variables(save_path)
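For orientation, a minimal sketch of the outer training loop that a policy like the one above is typically driven by is shown next. All names here (`agent`, `rollout_worker`, `n_cycles`, `n_batches`) and the energy-weighting argument values are illustrative assumptions, not part of the class.

def run_epoch(agent, rollout_worker, n_cycles=50, n_batches=40):
    """Hypothetical driver loop for the policy above (sketch, not from the original project)."""
    for _ in range(n_cycles):
        episode = rollout_worker.generate_rollouts()          # assumed rollout API
        agent.ddpg_store_episode(episode,
                                 dump_buffer=False,
                                 w_potential=1.0, w_linear=1.0, w_rotational=1.0,
                                 rank_method='none', clip_energy=999.0)
        for t in range(n_batches):
            agent.ddpg_train(t, dump_buffer=False)            # stage a batch, compute grads, update
        agent.ddpg_update_target_net()                        # Polyak-averaged target update
    return agent.logs(prefix='train')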
Example #18
0
File: ddpg.py  Project: poisonwine/GHER
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 reuse=False,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'GHER.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """

        # # print("\n\n\n\n1--", input_dims, "\n2--", buffer_size, "\n3--", hidden,
        #         "\n4--", layers, "\n5--", network_class, "\n6--", polyak, "\n7--", batch_size,
        #          "\n8--", Q_lr, "\n9--", pi_lr, "\n10--", norm_eps, "\n11--", norm_clip,
        #          "\n12--", max_u, "\n13--", action_l2, "\n14--", clip_obs, "\n15--", scope, "\n16--", T,
        #          "\n17--", rollout_batch_size, "\n18--", subtract_goals, "\n19--", relative_goals,
        #          "\n20--", clip_pos_returns, "\n21--", clip_return,
        #          "\n22--", sample_transitions, "\n23--", gamma)
        """
        在FetchReach-v1运行中参数值示例:
            input_dims (dict of ints):  {'o': 10, 'u': 4, 'g': 3, 'info_is_success': 1}  (o,u,g均作为网络的输入) 
            buffer_size (int):  1E6     (经验池样本总数)
            hidden (int): 256          (隐含层神经元个数)
            layers (int): 3            (三层神经网络)
            network_class (str):        GHER.ActorCritic'
            polyak (float): 0.95       (target-Network更新的平滑的参数)
            batch_size (int): 256      (批量大小)
            Q_lr (float): 0.001         (学习率)
            pi_lr (float): 0.001        (学习率)
            norm_eps (float): 0.01      (为避免数据溢出使用)
            norm_clip (float): 5        (norm_clip)
            max_u (float): 1.0          (动作的范围是[-1.0, 1.0])
            action_l2 (float): 1.0      (Actor网络的损失正则项系数)
            clip_obs (float): 200       (obs限制在 (-200, +200))
            scope (str): "ddpg"         (tensorflow 使用的 scope 命名域)
            T (int): 50                 (周期的交互次数)
            rollout_batch_size (int): 2 (number of parallel rollouts per DDPG agent)
            subtract_goals (function):  对goal进行预处理的函数, 输入为a和b,输出a-b
            relative_goals (boolean):   False  (如果需要对goal进行函数subtract_goals处理,则为True)
            clip_pos_returns (boolean): True   (是否需要将正的return消除)
            clip_return (float): 50     (将return的范围限制在[-clip_return, clip_return])
            sample_transitions (function):  her返回的函数. 参数由 config.py 定义
            gamma (float): 0.98         (Q 网络更新时使用的折扣因子)

            其中 sample_transition 来自与 HER 的定义,是关键部分
        """

        if self.clip_return is None:
            self.clip_return = np.inf

        # The network architecture and computation graph are built in actor_critic.py
        self.create_actor_critic = import_function(self.network_class)

        # Extract dimensions
        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']  # 10
        self.dimg = self.input_dims['g']  # 3
        self.dimu = self.input_dims['u']  # 4
        # print("+++", input_shapes)    #  {'o': (10,), 'u': (4,), 'g': (3,), 'info_is_success': (1,)}

        # https://www.tensorflow.org/performance/performance_models
        # StagingArea offers a simpler interface and can run in parallel with other stages on both
        # CPU and GPU: the input pipeline is split into 3 independent stages that execute in
        # parallel, which scales well and makes full use of large multi-core machines.

        # Define the required staging variables. Assuming self.dimo=10, self.dimg=5, self.dimu=5,
        # then stage_shapes = {'o': (None, 10), 'g': (None, 5), 'u': (None, 5)},
        # plus the variables used by the target network: {'o_2': (None, 10), 'g_2': (None, 5)}.
        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )  # the reward is a scalar
        self.stage_shapes = stage_shapes
        # After execution, self.stage_shapes =
        #       OrderedDict([('g', (None, 3)), ('o', (None, 10)), ('u', (None, 4)), ('o_2', (None, 10)), ('g_2', (None, 3)), ('r', (None,))])
        # i.e. g, o, u, the o_2 and g_2 used by the target network, and the reward r.

        # Create network.
        # Create tf placeholders according to stage_shapes, covering g, o, u, o_2, g_2, r:
        # self.buffer_ph_tf = [<tf.Tensor 'ddpg/Placeholder:0' shape=(?, 3) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_1:0' shape=(?, 10) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_2:0' shape=(?, 4) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_3:0' shape=(?, 10) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_4:0' shape=(?, 3) dtype=float32>,
        #                     <tf.Tensor 'ddpg/Placeholder_5:0' shape=(?,) dtype=float32>]
        with tf.variable_scope(self.scope):
            # Create the StagingArea
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            # Create the TensorFlow placeholders
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            # Op that stages the placeholder values into the StagingArea
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Replay-buffer setup.
        # With T = 50, after this block buffer_shapes =
        #         {'o': (51, 10), 'u': (50, 4), 'g': (50, 3), 'info_is_success': (50, 1), 'ag': (51, 3)}
        # Note that u, g, etc. record every sample of one episode and therefore have length 50,
        # while o and ag need one extra entry for the state reached after the last action.
        buffer_shapes = {
            key: (self.T if key != 'o' else self.T + 1, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)
        # print("+++", buffer_shapes)

        # buffer_size counts individual transitions (samples).
        # self.buffer_size=1E6  self.rollout_batch_size=2  buffer_size=1E6
        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

    def _random_action(self, n):
        """
            从 [-self.max_u, +self.max_u] 中随机采样 n 个动作
        """
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        """
            obs, goal, achieve_goal 进行预处理
            如果 self.relative_goal=True,则 goal = goal - achieved_goal
        """
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)  # 增加1维
            ag = ag.reshape(-1, self.dimg)  # 增加1维
            g = self.subtract_goals(g, ag)  # g = g - ag
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        """
            根据 self.main 网络选择动作,随后添加高斯噪声,clip,epsilon-greedy操作,输出处理后的动作
        """
        # 如果 self.relative_goal=True,则对 goal 进行预处理. 否则只进行 clip
        o, g = self._preprocess_og(o, ag, g)
        # 在调用本类的函数 self._create_network 后,创建了 self.main 网络和 self.target 网络,均为 ActorCritic 对象
        policy = self.target if use_target_net else self.main  # 根据 self.main 选择动作
        # actor 网络输出的动作的 tensor
        vals = [policy.pi_tf]

        # print("+++")
        # print(vals.shape)

        # Feed the actor's output back into the critic to obtain Q_pi_tf
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # Build the feed_dict with obs, goal and action, the inputs to the Actor and Critic
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }
        # Run the current policy network; ret[0] is the action and ret[1] the Q value
        ret = self.sess.run(vals, feed_dict=feed)

        # action postprocessing
        # Add Gaussian noise to the action; np.random.randn samples the noise from a standard normal.
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)  # clip after adding noise

        # Epsilon-greedy step with epsilon = random_eps.
        # np.random.binomial(1, random_eps, ...) is a Bernoulli draw that outputs 0 or 1, with 1 occurring with probability random_eps.
        # If it outputs 0, u += 0 leaves the action unchanged; if it outputs 1, u = u + (random_action - u) = random_action.
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u
        #
        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True, verbose=False):
        """
            episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
            调用 replay_buffer 中的 store_episode 函数对一个采样周期的样本进行存储
            o_stats 和 g_stats 分别更新和存储 obs 和 goal 的均值和标准差,并定期更新
        """

        # episode_batch 存储了 rollout.py 中 generate_rollout 产生的一个周期样本
        # episode_batch 是一个字典,键包括 o, g, u, ag, info,值的shape分别为
        #      o (2, 51, 10), u (2, 50, 4), g (2, 50, 3), ag (2, 51, 3), info_is_success (2, 50, 1)
        # 其中第1维是 worker的数目,第2维由周期长度决定

        self.buffer.store_episode(episode_batch, verbose=verbose)

        # Update the running mean and standard deviation kept by o_stats and g_stats.
        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch[
                'ag'][:, 1:, :]  # extract next_obs and next achieved_goal
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)  # convert episodes into a total number of transitions

            # Call the sampling function provided by sample_transitions.
            # episode_batch is a dict; keys and shapes: o (2, 51, 10) u (2, 50, 4) g (2, 50, 3) ag (2, 51, 3) info_is_success (2, 50, 1)
            #                                          o_2 (2, 50, 10)  ag_2 (2, 50, 3)
            # num_normalizing_transitions = 100, because there are 2 workers, each with one 50-step episode.
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            # After preprocessing, the sampled transitions are used to update o_stats and g_stats
            # (defined in Normalizer), which store the mean and std.
            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        """
            返回当前经验池的样本数量
        """
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        """
            Q_adam 和 pi_adam 为更新 actor网络 和 critic网络的运算符
        """
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        """
            返回损失函数和梯度
            Q_loss_tf, main.Q_pi_tf, Q_grad_tf, pi_grad_tf 均定义在 _create_network 函数中
        """

        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf,
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        """
            更新 main 的 Actor 和 Critic 网络
            更新的 op 均定义在 _create_network 中
        """
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        """
            调用 replay_buffer.py 中的 sample 函数进行采样,后者调用的采样方法来源于 her.py 中的定义
            返回的样本组成 batch,用于 self.stage_batch 函数中构建 feed_dict
            feed_dict将作为 选择动作 和 更新网络参数 的输入

            调用采样一个批量的样本,随后对 o 和 g 进行预处理. 样本的 key 包括 o, o_2, ag, ag_2, g
        """
        # 调用sample后返回transition为字典, key 和 val.shape:
        # o (256, 10) u (256, 4) g (256, 3) info_is_success (256, 1) ag (256, 3) o_2 (256, 10) ag_2 (256, 3) r (256,)
        # print("In DDPG: ", self.batch_size)
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))
        # for TensorBoard visualization
        self.tfboard_sample_batch = batch
        self.tfboard_sample_tf = self.buffer_ph_tf

    def train(self, stage=True):
        """
            计算梯度,随后更新
            train 中执行参数更新之前先执行了 self.stage_batch,用于构建训练使用的feed_dict. 该函数中调用了 
                    self.sample_batch 函数,后者又调用了 self.buffer.sample,后者调用了 config.py 中的 config_her, 后者对 her.py 的函数进行参数配置.
            train 中的运算符在 self._create_network 中定义.
        """
        if stage:
            self.stage_batch()  # 返回使用 her.py 的采样方式构成的 feed_dict 用于计算梯度
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        """
            更新 target 网络,update_target_net_op 定义在函数 _create_network 中
        """
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        """
            定义计算 Actor 和 Critic 损失所需要的计算流图
        """
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()
        # running averages
        # Define Normalizer objects used to normalize obs and goal respectively.
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        # Data structure holding one mini-batch, as an OrderedDict; after this block batch_tf is:
        # OrderedDict([('g', <tf.Tensor 'ddpg/ddpg/StagingArea_get:0' shape=(?, 3) dtype=float32>),
        #              ('o', <tf.Tensor 'ddpg/ddpg/StagingArea_get:1' shape=(?, 10) dtype=float32>),
        #              ('u', <tf.Tensor 'ddpg/ddpg/StagingArea_get:2' shape=(?, 4) dtype=float32>),
        #              ('o_2', <tf.Tensor 'ddpg/ddpg/StagingArea_get:3' shape=(?, 10) dtype=float32>),
        #              ('g_2', <tf.Tensor 'ddpg/ddpg/StagingArea_get:4' shape=(?, 3) dtype=float32>),
        #              ('r', <tf.Tensor 'ddpg/Reshape:0' shape=(?, 1) dtype=float32>)])
        # The batch_tf variables defined here serve as the inputs to the neural networks.

        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # Build the main network from ActorCritic.py.
        # No explicit arguments need to be passed when constructing the ActorCritic network:
        # self.__dict__ hands the matching DDPG attributes straight to ActorCritic.
        # print(self.main.__dict__)
        # {'inputs_tf': OrderedDict([('g', <tf.Tensor 'ddpg/ddpg/StagingArea_get:0' shape=(?, 3) dtype=float32>), ('o', <tf.Tensor 'ddpg/ddpg/StagingArea_get:1' shape=(?, 10) dtype=float32>), ('u', <tf.Tensor 'ddpg/ddpg/StagingArea_get:2' shape=(?, 4) dtype=float32>), ('o_2', <tf.Tensor 'ddpg/ddpg/StagingArea_get:3' shape=(?, 10) dtype=float32>), ('g_2', <tf.Tensor 'ddpg/ddpg/StagingArea_get:4' shape=(?, 3) dtype=float32>), ('r', <tf.Tensor 'ddpg/Reshape:0' shape=(?, 1) dtype=float32>)]),
        # 'net_type': 'main', 'reuse': False, 'buffer_size': 1000000, 'hidden': 256, 'layers': 3, 'network_class': 'GHER.actor_critic:ActorCritic',
        # 'polyak': 0.95, 'batch_size': 256, 'Q_lr': 0.001, 'pi_lr': 0.001, 'norm_eps': 0.01, 'norm_clip': 5, 'max_u': 1.0,
        # 'action_l2': 1.0, 'clip_obs': 200.0, 'scope': 'ddpg', 'relative_goals': False, 'input_dims': {'o': 10, 'u': 4, 'g': 3, 'info_is_success': 1},
        # 'T': 50, 'clip_pos_returns': True, 'clip_return': 49.996, 'rollout_batch_size': 2, 'subtract_goals': <function simple_goal_subtract at 0x7fcf72caa510>, 'sample_transitions': <function make_sample_her_transitions.<locals>._sample_her_transitions at 0x7fcf6e2ce048>,
        # 'gamma': 0.98, 'info': {'env_name': 'FetchReach-v1'}, 'use_mpi': True, 'create_actor_critic': <class 'GHER.actor_critic.ActorCritic'>,
        # 'dimo': 10, 'dimg': 3, 'dimu': 4, 'stage_shapes': OrderedDict([('g', (None, 3)), ('o', (None, 10)), ('u', (None, 4)), ('o_2', (None, 10)), ('g_2', (None, 3)), ('r', (None,))]), 'staging_tf': <tensorflow.python.ops.data_flow_ops.StagingArea object at 0x7fcf6e2dddd8>,
        # 'buffer_ph_tf': [<tf.Tensor 'ddpg/Placeholder:0' shape=(?, 3) dtype=float32>, <tf.Tensor 'ddpg/Placeholder_1:0' shape=(?, 10) dtype=float32>, <tf.Tensor 'ddpg/Placeholder_2:0' shape=(?, 4) dtype=float32>, <tf.Tensor 'ddpg/Placeholder_3:0' shape=(?, 10) dtype=float32>, <tf.Tensor 'ddpg/Placeholder_4:0' shape=(?, 3) dtype=float32>, <tf.Tensor 'ddpg/Placeholder_5:0' shape=(?,) dtype=float32>],
        # 'stage_op': <tf.Operation 'ddpg/ddpg/StagingArea_put' type=Stage>, 'sess': <tensorflow.python.client.session.InteractiveSession object at 0x7fcf6e2dde10>, 'o_stats': <GHER.normalizer.Normalizer object at 0x7fcf6e2ee940>, 'g_stats': <GHER.normalizer.Normalizer object at 0x7fcf6e2ee898>,
        # 'o_tf': <tf.Tensor 'ddpg/ddpg/StagingArea_get:1' shape=(?, 10) dtype=float32>, 'g_tf': <tf.Tensor 'ddpg/ddpg/StagingArea_get:0' shape=(?, 3) dtype=float32>, 'u_tf': <tf.Tensor 'ddpg/ddpg/StagingArea_get:2' shape=(?, 4) dtype=float32>, 'pi_tf': <tf.Tensor 'ddpg/main/pi/mul:0' shape=(?, 4) dtype=float32>, 'Q_pi_tf': <tf.Tensor 'ddpg/main/Q/_3/BiasAdd:0' shape=(?, 1) dtype=float32>, '_input_Q': <tf.Tensor 'ddpg/main/Q/concat_1:0' shape=(?, 17) dtype=float32>, 'Q_tf': <tf.Tensor 'ddpg/main/Q/_3_1/BiasAdd:0' shape=(?, 1) dtype=float32>}

        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()

        # o_2 and g_2 are used to build the target network
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf[
                'g_2']  # the target network computes the target-Q value, so o and g must use the next state's values
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        # Computing the critic's target-Q value requires both the actor's and the critic's target networks.
        # target_Q_pi_tf uses the next state, i.e. o_2 and g_2.
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        # The critic loss is the squared difference between target_tf and Q_tf; note that no gradient flows through target_tf.
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
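        # i.e. L_Q = E[(clip(r + gamma * Q_target(o_2, pi_target(o_2))) - Q(o, u))^2], the standard TD(0) critic loss.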

        # The actor loss is the negative of the Q value obtained by feeding the actor's output into the main critic.
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        # Add L2 regularization on the actions to the actor loss.
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
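        # i.e. L_pi = -E[Q(o, pi(o))] + action_l2 * E[(pi(o) / max_u)^2].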

        # compute gradients
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf,
                                   self._vars('main/Q'))  # pair gradients with their variables
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars(
            'main/pi')  # collect the Actor and Critic parameters together
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(  # target initialization: copy the main network parameters directly into the target
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(  # target update: blend the target and main networks with weight polyak
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))
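        # i.e. target <- polyak * target + (1 - polyak) * main; with polyak = 0.95 the target tracks the main network slowly.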

        # # TensorBoard visualization
        # tf.summary.scalar("Q_target-Q-mean", tf.reduce_mean(target_tf))
        # tf.summary.histogram("Q_target-Q", target_tf)
        # tf.summary.scalar("Q_Td-error-mean", tf.reduce_mean(target_tf - self.main.Q_tf))
        # tf.summary.histogram("Q_Td-error", target_tf - self.main.Q_tf)
        # tf.summary.scalar("Q_reward-mean", tf.reduce_mean(batch_tf['r']))
        # tf.summary.histogram("Q_reward", batch_tf['r'])
        # tf.summary.scalar("Q_loss_tf", self.Q_loss_tf)
        # tf.summary.scalar("pi_loss_tf", self.pi_loss_tf)
        # self.merged = tf.summary.merge_all()

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def tfboard_func(self, summary_writer, step):
        """
            tensorboard可视化
        """
        self.sess.run(self.stage_op,
                      feed_dict=dict(
                          zip(self.tfboard_sample_tf,
                              self.tfboard_sample_batch)))
        summary = self.sess.run(self.merged)
        summary_writer.add_summary(summary, global_step=step)

        print("S" + str(step), end=",")

    def __getstate__(self):
        """
            Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)

    # -----------------------------------------
    def updata_loss_all(self, verbose=False):
        assert self.buffer.current_size > 0
        idxes = np.arange(self.buffer.current_size)
        print("--------------------------------------")
        print("Updata All loss start...")
        self.buffer.update_rnnLoss(idxes, verbose=verbose)
        print("Updata All loss end ...")
Example #19
0
class DDPG(object):
    @store_args
    def __init__(self,
                 use_aux_tasks,
                 input_dims,
                 image_input_shapes,
                 buffer_size,
                 hidden,
                 layers,
                 dim_latent_repr,
                 cnn_nonlinear,
                 use_bottleneck_layer,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 scope,
                 T,
                 rollout_batch_size,
                 clip_pos_returns,
                 clip_return,
                 log_loss,
                 sample_transitions,
                 gamma,
                 rank,
                 serialized=False,
                 reuse=False,
                 clip_grad_range=None,
                 aux_filter_interval=None,
                 scale_grad_by_procs=False,
                 aux_update_interval=5,
                 aux_base_lr=5,
                 **kwargs):
        """ See the documentation in main.py """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(
            'cnn_actor_critic:CNNActorCritic')

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        if self.use_aux_tasks:
            self.dim_bw_frame = self.input_dims['info_bw_frame']
            self.dim_op_flow = self.input_dims['info_op_flow']
            self.dim_transformed_frame = self.input_dims[
                'info_transformed_frame']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()

        include_info = [
            'info_state_obs', 'info_transformed_frame', 'info_transformation',
            'info_op_flow', 'info_bw_frame'
        ]

        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_') and not key in include_info:
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            if self.use_aux_tasks:
                # Initialize OL-AUX
                self.num_auxiliary_tasks = 5
                self.aux_weights_lr = self.aux_base_lr * self.aux_update_interval

                self.aux_weight_vector_Q_tf = tf.Variable(
                    initial_value=1 * tf.ones(self.num_auxiliary_tasks),
                    dtype=tf.float32,
                    name='aux_weights')
                self.aux_weight_grads_buffer = []
                self.log_aux_losses_Q = self.log_aux_tasks_losses_pi = None  # Logging buffer for aux losses
                if self.aux_filter_interval is not None:
                    self.all_grad_history = deque(
                        maxlen=self.aux_filter_interval)

            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=self.reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key:
            (self.T if key != 'o' and not key.startswith('info_') else self.T +
             1, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }
        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = o.copy(), g.copy()
            # No need to preprocess the o_2 and g_2 since this is only used for stats
            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])
            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

            if self.use_aux_tasks:
                self.bw_frame_stats.update(transitions['info_bw_frame'])
                self.op_flow_stats.update(transitions['info_op_flow'])
                self.transformed_frame_stats.update(
                    transitions['info_transformed_frame'])
                self.bw_frame_stats.recompute_stats()
                self.op_flow_stats.recompute_stats()
                self.transformed_frame_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        assert not self.serialized
        run_vars = [
            self.Q_loss_tf, self.pi_loss_tf, self.Q_grad_tf, self.pi_grad_tf
        ]

        if self.use_aux_tasks:
            run_vars.append(self.main_task_Q_cnn_grad_flatten_tf)
            run_vars.extend(
                self.main.loss_auxiliary_tasks_Q_tf)  # Q Aux losses
            run_vars.extend(self.aux_Q_cnn_grads_flatten_tf)  # Q Aux grads
            run_vars.extend(
                self.main.loss_auxiliary_tasks_pi_tf)  # pi Aux losses
            assert len(
                self.aux_Q_cnn_grads_flatten_tf) == self.num_auxiliary_tasks
            rets = self.sess.run(run_vars)

            aux_losses_pi = copy.copy(rets[-self.num_auxiliary_tasks:])
            aux_grads_Q = copy.copy(
                rets[-2 * self.num_auxiliary_tasks:-self.num_auxiliary_tasks])
            aux_losses_Q = copy.copy(rets[-3 * self.num_auxiliary_tasks:-2 *
                                          self.num_auxiliary_tasks])

            rets = rets[:-3 * self.num_auxiliary_tasks] + [aux_losses_pi] + [
                aux_losses_Q
            ] + [aux_grads_Q]
        else:
            rets = self.sess.run(run_vars)
        return rets

    # noinspection PyAttributeOutsideInit
    def train(self, stage=True):
        # import cProfile, pstats, io
        # pr = cProfile.Profile()
        # pr.enable()
        if stage:
            self.stage_batch()
        if self.use_aux_tasks:
            critic_loss, actor_loss, Q_grad, pi_grad, main_task_grad, \
            aux_losses_pi, aux_losses_Q, aux_task_grads_Q = self._grads()

            self.log_aux_losses_Q = [loss for loss in aux_losses_Q]
            self.log_aux_losses_pi = [loss for loss in aux_losses_pi]

            self._update(Q_grad, pi_grad)
            self._update_aux_weights(main_task_grad, aux_task_grads_Q)
        else:
            critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
            self._update(Q_grad, pi_grad)
        # pr.disable()
        # s = io.StringIO()
        # ps = pstats.Stats(pr, stream=s).sort_stats('time')
        # ps.print_stats(20)
        # print(s.getvalue())

        return critic_loss, actor_loss

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        transitions['o'], transitions['g'] = o.copy(), g.copy()
        transitions['o_2'], transitions['g_2'] = o_2.copy(), g.copy()

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def _update_aux_weights(self, main_task_grad, aux_task_grads):
        """
        Called during each iteration. But only update the auxiliary task weights according to the update interval
        :param main_task_grad: Gradient of the main task (of cnn)
        :param aux_task_grads: A list of the gradients from each of the auxiliary tasks (of cnn)
        """
        main_task_grad, aux_task_grads = self.aux_weight_updater.get_syncd_grad(
            main_task_grad, aux_task_grads)

        aux_weight_grad = np.zeros([self.num_auxiliary_tasks])
        aux_task_grads = np.array(aux_task_grads)
        main_task_grad = np.array(main_task_grad)

        if self.aux_filter_interval is not None:
            self.all_grad_history.append(
                (main_task_grad.copy(), aux_task_grads.copy()))
            main_task_grad = np.mean(np.array(
                [grad[0] for grad in self.all_grad_history]),
                                     axis=0)
            aux_task_grads = np.mean(np.array(
                [grad[1] for grad in self.all_grad_history]),
                                     axis=0)

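        # OL-AUX: each auxiliary weight's gradient is the dot product between that task's CNN
        # gradient and the main-task CNN gradient, scaled by the critic learning rate.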
        for i, aux_task_grad in enumerate(aux_task_grads):
            aux_weight_grad[i] = self.Q_lr * np.dot(aux_task_grad,
                                                    main_task_grad)
        self.aux_weight_grads_buffer.append(aux_weight_grad)

        if len(self.aux_weight_grads_buffer) == self.aux_update_interval:
            aggregate_aux_weight_grad = np.mean(np.array(
                self.aux_weight_grads_buffer),
                                                axis=0)
            self.aux_weight_updater.update(self.aux_weights_lr *
                                           aggregate_aux_weight_grad)
            self.aux_weight_grads_buffer = []

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))
        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()
        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        if self.use_aux_tasks:
            with tf.variable_scope('bw_frame_stats') as vs:
                if reuse:
                    vs.reuse_variables()
                self.bw_frame_stats = Normalizer(self.dim_bw_frame,
                                                 self.norm_eps,
                                                 self.norm_clip,
                                                 sess=self.sess)

            with tf.variable_scope('op_flow_stats') as vs:
                if reuse:
                    vs.reuse_variables()
                self.op_flow_stats = Normalizer(self.dim_op_flow,
                                                self.norm_eps,
                                                self.norm_clip,
                                                sess=self.sess)

            with tf.variable_scope('transformed_frame_stats') as vs:
                if reuse:
                    vs.reuse_variables()
                self.transformed_frame_stats = Normalizer(
                    self.dim_transformed_frame,
                    self.norm_eps,
                    self.norm_clip,
                    sess=self.sess)

        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            if self.use_aux_tasks:
                self.main.build_auxiliary_tasks()
            vs.reuse_variables()

        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)

        if self.use_aux_tasks and self.log_loss:
            self.pi_loss_tf = tf.clip_by_value(
                self.pi_loss_tf,
                np.finfo(float).eps, np.Inf)  # So that log can be applied
            self.Q_loss_tf = tf.log(self.Q_loss_tf)
            self.pi_loss_tf = tf.log(self.pi_loss_tf)

        self.action_l2_loss_tf = self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
        self.pi_loss_tf += self.action_l2_loss_tf

        if self.use_aux_tasks:
            if self.log_loss:
                for i, loss_tf in enumerate(
                        self.main.loss_auxiliary_tasks_Q_tf):
                    self.Q_loss_tf += tf.stop_gradient(
                        self.aux_weight_vector_Q_tf[i]) * tf.log(
                            loss_tf + self.log_min_loss)

                # Use the same weight of the auxiliary tasks in Q function also for pi.
                # Also possible to use separate aux weight vectors for Q and pi
                for i, loss_tf in enumerate(
                        self.main.loss_auxiliary_tasks_pi_tf):
                    self.pi_loss_tf += tf.stop_gradient(
                        self.aux_weight_vector_Q_tf[i]) * tf.log(
                            loss_tf + self.log_min_loss)
            else:
                for i, loss_tf in enumerate(
                        self.main.loss_auxiliary_tasks_Q_tf):
                    self.Q_loss_tf += tf.stop_gradient(
                        self.aux_weight_vector_Q_tf[i]) * loss_tf
                for i, loss_tf in enumerate(
                        self.main.loss_auxiliary_tasks_pi_tf):
                    self.pi_loss_tf += tf.stop_gradient(
                        self.aux_weight_vector_Q_tf[i]) * loss_tf

        Q_grads_tf = tf.gradients(self.Q_loss_tf,
                                  self._vars('main/Q'),
                                  name='Q_gradient')
        pi_grads_tf = tf.gradients(self.pi_loss_tf,
                                   self._vars('main/pi'),
                                   name='pi_gradient')

        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'),
                                       clip_grad_range=self.clip_grad_range)
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'),
                                        clip_grad_range=self.clip_grad_range)

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'),
                              scale_grad_by_procs=self.scale_grad_by_procs)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=self.scale_grad_by_procs)
        if self.use_aux_tasks:
            self.aux_weight_updater = MpiAuxUpdate(self._vars('aux_weights'),
                                                   scale_grad_by_procs=True)

        if self.use_aux_tasks:
            # Get gradient from the auxiliary tasks w.r.t. the shared cnn
            if self.log_loss:
                aux_Q_cnn_grads_tf = [
                    tf.gradients(
                        tf.log(loss_tf + self.log_min_loss, name=loss_name),
                        self._vars('main/Q/cnn'))
                    for (loss_tf,
                         loss_name) in zip(self.main.loss_auxiliary_tasks_Q_tf,
                                           self.main.name_auxiliary_tasks)
                ]
            else:
                aux_Q_cnn_grads_tf = [
                    tf.gradients(loss_tf, self._vars('main/Q/cnn'))
                    for loss_tf in self.main.loss_auxiliary_tasks_Q_tf
                ]
            self.aux_Q_cnn_grads_flatten_tf = [
                flatten_grads(grads=aux_grad_tf,
                              var_list=self._vars('main/Q/cnn'),
                              clip_grad_range=self.clip_grad_range)
                for aux_grad_tf in aux_Q_cnn_grads_tf
            ]

            # Get gradient of cnn from the main task
            self.main_task_Q_cnn_grad_tf = tf.gradients(
                self.Q_loss_tf,
                self._vars('main/Q/cnn'),
                name='aux_update_main_gradient_Q')
            self.main_task_Q_cnn_grad_flatten_tf = flatten_grads(
                grads=self.main_task_Q_cnn_grad_tf,
                var_list=self._vars('main/Q/cnn'),
                clip_grad_range=self.clip_grad_range)

        # polyak averaging, excluding the auxiliary variables
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.main_vars = [
            var for var in self.main_vars
            if var not in (self._vars('main/Q/aux') +
                           self._vars('main/pi/aux'))
        ]
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        assert len(self.main_vars) == len(self.target_vars)
        self.stats_vars = self._global_vars('o_stats') + self._global_vars('bw_frame_stats') + \
                          self._global_vars('op_flow_stats') + self._global_vars('g_stats') + \
                          self._global_vars('transformed_frame_stats')

        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        tf.variables_initializer(self._global_vars('')).run()

        self._sync_optimizers()
        if self.use_aux_tasks:
            self.aux_weight_updater.sync()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        transitions = self.buffer.sample(self.batch_size)
        action_mean = np.mean(np.abs(transitions['u']))
        action_std = np.std(transitions['u'])
        logs += [('buffer_a/abs_mean', action_mean)]
        logs += [('buffer_a/std', action_std)]

        if self.use_aux_tasks:
            # Log auxiliary task losses (After the log operator)
            for (aux_task_name,
                 aux_task_loss) in zip(self.main.name_auxiliary_tasks,
                                       self.log_aux_losses_Q):
                logs += [('aux_losses_Q/' + aux_task_name, aux_task_loss)]

            # Log auxiliary task weights
            curr_aux_weights = self.sess.run(self.aux_weight_vector_Q_tf)
            for (aux_task_name,
                 aux_task_weight) in zip(self.main.name_auxiliary_tasks,
                                         curr_aux_weights):
                logs += [('aux_weights_Q/' + aux_task_name, aux_task_weight)]
        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', '_updater', 'buffer', 'sess',
            '_stats', 'main', 'target', 'lock', 'env', 'sample_transitions',
            'stage_shapes', 'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None
        state['serialized'] = True
        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
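
The staging boilerplate is identical across these DDPG variants: build an OrderedDict of
stage_shapes, create one float32 placeholder per entry, stage a sampled batch with a single
put() op, and read the tensors back inside the graph with get() in the same order. Below is a
minimal, self-contained sketch of that round trip; the keys, shapes, and dummy loss are
illustrative only and are not taken from any of the examples.

from collections import OrderedDict

import numpy as np
import tensorflow as tf
from tensorflow.contrib.staging import StagingArea

stage_shapes = OrderedDict([('o', (None, 10)), ('u', (None, 4)), ('r', (None,))])

# One StagingArea slot per key, all float32, in the same order as stage_shapes.
staging_tf = StagingArea(dtypes=[tf.float32 for _ in stage_shapes],
                         shapes=list(stage_shapes.values()))
buffer_ph_tf = [tf.placeholder(tf.float32, shape=shape)
                for shape in stage_shapes.values()]
stage_op = staging_tf.put(buffer_ph_tf)

# get() returns the staged tensors in put() order; rebuild the keyed batch from it.
batch = staging_tf.get()
batch_tf = OrderedDict([(key, batch[i]) for i, key in enumerate(stage_shapes.keys())])
loss_tf = tf.reduce_mean(tf.square(batch_tf['o']))  # stand-in for the real losses

with tf.Session() as sess:
    feed = {
        buffer_ph_tf[0]: np.random.randn(32, 10).astype(np.float32),  # 'o'
        buffer_ph_tf[1]: np.random.randn(32, 4).astype(np.float32),   # 'u'
        buffer_ph_tf[2]: np.random.randn(32).astype(np.float32),      # 'r'
    }
    sess.run(stage_op, feed_dict=feed)  # stage one batch via the placeholders
    print(sess.run(loss_tf))            # downstream ops then run without a feed_dict
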
Example #20
class DDPG(object):
    @store_args
    def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
                 Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
                 rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
                 sample_transitions, gamma, replay_k, reward_fun=None, reuse=False, **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function) function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        # Create the actor-critic networks. network_class is defined in actor_critic.py
        # and is passed in as network_class when the DDPG object is created.
        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        # Next state (o_2) and goal at next state (g_2)
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None,)
        self.stage_shapes = stage_shapes

        # Adding variable for correcting bias - Ameet
        self.stage_shapes_new = OrderedDict()
        self.stage_shapes_new['bias'] = (None,)
        ##############################################
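        # 'bias' holds the per-transition importance-sampling weights returned by the prioritized
        # replay buffer; they are staged separately below and multiplied into the TD error in
        # _create_network.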

        # Create network
        # A StagingArea is a TensorFlow construct used to feed data onto the GPU.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)
            
            # Adding bias term from section 3.4 - Ameet
            self.staging_tf_new = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes_new.keys()],
                shapes=list(self.stage_shapes_new.values()))
            self.buffer_ph_tf_new = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes_new.values()]
            self.stage_op_new = self.staging_tf_new.put(self.buffer_ph_tf_new)
            ############################################

            self._create_network(reuse=reuse)

        # Configure the replay buffer
        buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
                         for key, val in input_shapes.items()}
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T+1, self.dimg)

        buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size

        # conf represents the parameters required for initializing the priority_queue
        # Remember: The bias gets annealed only conf.total_steps number of times
        conf = {'size': self.buffer_size,
                'learn_start': self.batch_size,
                'batch_size': self.batch_size,
                # Using some heuristic to set the partition_num as it matters only when the buffer is not full (unlikely)
                'partition_size': (self.replay_k)*100}

        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions, conf, self.replay_k)

        # global_steps represents the number of batches used for updates
        self.global_step = 0
        self.debug = {}

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

    # Preprocessing by clipping the goal and state variables
    # Not sure about the relative_goal part
    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    # target is the target policy network and main is the one that gets updated;
    # target is updated by moving its parameters towards those of main.
    # pi_tf is the output of the policy network; Q_pi_tf is the critic output used to train pi_tf,
    # i.e., Q_pi_tf evaluates the value of pi_tf's action,
    # while Q_tf evaluates the action that was actually taken.
    def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf: o.reshape(-1, self.dimo),
            policy.g_tf: g.reshape(-1, self.dimg),
            policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(*u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        ###### Remove the l value - Supposed to be a list of length 2
        # First entry consists of transitions with actual goals and second is alternate goals
        self.buffer.store_episode(episode_batch)

        # ###### Debug
        # # This functions was used to check the hypothesis that if TD error is high
        # # for a state with some goal, it is high for that states with all other goals
        # self.debug_td_error_alternate_actual(debug_transitions)


        # Updating stats

        ## Change this--------------
        update_stats = False
        ###--------------------------
        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
            transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()


    # This function is purely for Debugging purposes
    def debug_td_error_alternate_actual(self, debug_transitions):
        actual_transitions, alternate_transitions = debug_transitions[0], debug_transitions[1]
        actual_transitions, alternate_transitions = self.td_error_convert_to_format(actual_transitions),\
                                                    self.td_error_convert_to_format(alternate_transitions)

        # Calculated priorities
        priorities = []
        priorities.append(self.get_priorities(actual_transitions))
        priorities.append(self.get_priorities(alternate_transitions))

        f = open('act_alt_goals.txt', 'a')

        # Length of priorities[0] is 100 and priorities[1] is 400
        for i in range(len(priorities[0])):
            f.write(str(priorities[0][i])+" : ")
            for k in range(4):
                f.write(str(priorities[1][i*self.replay_k+k])+" : ")
            f.write('\n')

        f.write("Done Storing One Rollout\n\n\n")
        # f.write('The number of transitions are: '+str(len(priorities[0]))+" :: "+str(len(priorities[1]))+"\n")


    # This function is purely for Debugging purposes
    def td_error_convert_to_format(self, sample_transitions):
        # sample_transitions is now a list of transitions, convert it to the usual {key: batch X dim_key}
        keys = sample_transitions[0].keys()
        # print("Keys in _sample_her_transitions are: "+str(keys))
        transitions = {}
        for key in keys:
            # Initialize for all the keys
            transitions[key] = []

            # Add transitions one by one to the list
            for single_transition in range(len(sample_transitions)):
                transitions[key].append(sample_transitions[single_transition][key])
            transitions[key] = np.array(transitions[key])
        
        # Reconstruct info dictionary for reward  computation.
        info = {}
        for key, value in transitions.items():
            if key.startswith('info_'):
                info[key.replace('info_', '')] = value

        # print("The keys in transitions are: "+str(transitions.keys()))
        reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
        reward_params['info'] = info
        transitions['r'] = self.reward_fun(**reward_params)

        # transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
        #                for k in transitions.keys()}


        return transitions

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    # Adam update for Q and pi networks
    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    # Sample a batch for mini batch gradient descent, already defined in replay_buffer.py
    def sample_batch(self):
        # Increment the global step
        self.global_step += 1

        transitions, w, rank_e_id = self.buffer.sample(self.batch_size, self.global_step, self.uniform_priority)
        priorities = self.get_priorities(transitions)

        # ##### Debug function
        # self.debug_td_error(transitions, priorities)
        # #####
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)

        # # Remove
        # print("Stage Shape keys in sample_batch are: "+str(self.stage_shapes.keys()))

        transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]

        # Updates the priorities of the sampled transitions in the priority queue
        self.buffer.update_priority(rank_e_id, priorities)

        return transitions_batch, [w]


    # This function is purely for debugging purposes
    def debug_td_error(self, transitions, priorities):
        f = open('td_error_debug.txt', 'a')
        self.debug['actual_goals'] = 0
        self.debug['alternate_goals'] = 0
        trans = transitions['is_actual_goal']
        for t in range(trans.shape[0]):
            if trans[t]:
                self.debug['actual_goals'] += 1
                # f.write('Actual goal transition: '+str(priorities[t])+'\n')
            else:
                self.debug['alternate_goals'] += 1
                # f.write('Alternate goal transition: '+str(priorities[t])+'\n')
        f.write('Ratio is: '+str(float(self.debug['alternate_goals'])/self.debug['actual_goals'])+'\n')
        f.close()
        del transitions['is_actual_goal']

    ###### Debug End

    def get_priorities(self, transitions):
        pi_target = self.target.pi_tf
        Q_pi_target = self.target.Q_pi_tf
        Q_main = self.main.Q_tf


        o = transitions['o']
        o_2 = transitions['o_2']
        u = transitions['u']
        g = transitions['g']
        r = transitions['r']
        # Check this with Srikanth
        ag = transitions['ag']

        priorities = np.zeros(o.shape[0])

        # file_obj = open("priorities_print","a")
        for i in range(o.shape[0]):
            o_2_i = np.clip(o_2[i], -self.clip_obs, self.clip_obs)
            o_i, g_i = self._preprocess_og(o[i], ag[i], g[i])
            u_i = u[i]

            # Not sure about the o_2_i.size // self.dimo; we probably need not pass one
            # transition at a time (a batched variant is sketched after this example).
            feed_target = {
                self.target.o_tf: o_2_i.reshape(-1, self.dimo),
                self.target.g_tf: g_i.reshape(-1, self.dimg),
                self.target.u_tf: np.zeros((o_2_i.size // self.dimo, self.dimu), dtype=np.float32)
            }

            # u_tf for main network is just the action taken at that state
            feed_main = {
                self.main.o_tf: o_i.reshape(-1, self.dimo),
                self.main.g_tf: g_i.reshape(-1, self.dimg),
                self.main.u_tf: u_i.reshape(-1, self.dimu)
            }

            TD = r[i] + self.gamma*self.sess.run(Q_pi_target, feed_dict=feed_target) - self.sess.run(Q_main, feed_dict=feed_main)

            priorities[i] = abs(TD)

            text = str(TD)
            # file_obj.write(text)
        # file_obj.close()

        return priorities


    def stage_batch(self, batch=None):
        bias = None
        if batch is None:
            batch, bias = self.sample_batch()
            # print("Batch type is: "+str(type(batch)))
            # print("Batch Shape is: "+str(len(batch)))
            # print(str(type(batch[0])))
        assert len(self.buffer_ph_tf) == len(batch), "Expected: "+str(len(self.buffer_ph_tf))+" Got: "+str(len(batch))
        self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))

        ##### Adding for bias - Ameet
        # Guard against an externally supplied batch, which carries no importance-sampling weights.
        if bias is not None:
            assert len(self.buffer_ph_tf_new) == len(bias), "Expected: "+str(len(self.buffer_ph_tf_new))+" Got: "+str(len(bias))
            self.sess.run(self.stage_op_new, feed_dict=dict(zip(self.buffer_ph_tf_new, bias)))
        #####
        
        # print("Completed stage batch")

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        # print("In ddpg priority:: The shapes of Q_grad and pi_grad are: "+str(Q_grad.shape)+"::"+str(pi_grad.shape))
        # print("Their types are::"+str(type(Q_grad)))
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([(key, batch[i])
                                for i, key in enumerate(self.stage_shapes.keys())])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        ########### Getting the bias terms - Ameet
        bias = self.staging_tf_new.get()
        bias_tf = OrderedDict([(key, bias[i])
                                for i, key in enumerate(self.stage_shapes_new.keys())])
        bias_tf['bias'] = tf.reshape(bias_tf['bias'], [-1, 1])
        #######################################

        # Create main and target networks, each will have a pi_tf, Q_tf and Q_pi_tf
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(
                target_batch_tf, net_type='target', **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        ############## Added for bias - Ameet
        error = (tf.stop_gradient(target_tf) - self.main.Q_tf) * bias_tf['bias']
        self.Q_loss_tf = tf.reduce_mean(tf.square(error))
        # self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf * bias_tf['bias'])
        # Note that the following statement does not include bias because of the remark in the IEEE paper
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        ##############
        # Regularization - L2 - Check - Penalty for taking the best action
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        ################### Shape Info
        ####Shape of Q_grads_tf is: 8
        ####Shape of Q_grads_tf[0] is: (17, 256)
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)

        # polyak averaging
        # 'main/Q' is a way of communicating the scope of the variables
        # _vars has a way to understand this
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
        # Update the networks
        # target net is updated by using polyak averaging
        # target net is initialized by just copying the main net
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
                             'main', 'target', 'lock', 'env', 'sample_transitions',
                             'stage_shapes', 'create_actor_critic']

        state = {k: v for k, v in self.__dict__.items() if all([not subname in k for subname in excluded_subnames])}
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert(len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
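
get_priorities above runs two sess.run calls per transition inside a Python loop. The same
|TD-error| priorities can be computed in one batched pass; the method below is a hedged sketch
of such a drop-in replacement (get_priorities_batched is a hypothetical name, and the code
assumes the same self.main / self.target networks, _preprocess_og, attributes, and numpy import
shown in the example above). It is not part of the original implementation.

    def get_priorities_batched(self, transitions):
        """Batched |TD-error| priorities: two sess.run calls for the whole batch (sketch)."""
        o, o_2, u = transitions['o'], transitions['o_2'], transitions['u']
        g, r, ag = transitions['g'], transitions['r'], transitions['ag']

        o_2 = np.clip(o_2, -self.clip_obs, self.clip_obs)
        o, g = self._preprocess_og(o, ag, g)

        # Q(s', pi(s')) from the target network; u_tf is fed zeros, as in get_priorities above.
        q_next = self.sess.run(self.target.Q_pi_tf, feed_dict={
            self.target.o_tf: o_2.reshape(-1, self.dimo),
            self.target.g_tf: g.reshape(-1, self.dimg),
            self.target.u_tf: np.zeros((o_2.shape[0], self.dimu), dtype=np.float32),
        })
        # Q(s, a) from the main network, evaluated at the actions actually taken.
        q_taken = self.sess.run(self.main.Q_tf, feed_dict={
            self.main.o_tf: o.reshape(-1, self.dimo),
            self.main.g_tf: g.reshape(-1, self.dimg),
            self.main.u_tf: u.reshape(-1, self.dimu),
        })
        td_error = r.reshape(-1, 1) + self.gamma * q_next - q_taken
        return np.abs(td_error).reshape(-1)
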
Example #21
class DDPG(object):
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 time_horizon,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 reuse=False):
        """
        Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        :param input_dims: ({str: int}) dimensions for the observation (o), the goal (g), and the actions (u)
        :param buffer_size: (int) number of transitions that are stored in the replay buffer
        :param hidden: (int) number of units in the hidden layers
        :param layers: (int) number of hidden layers
        :param network_class: (str) the network class that should be used (e.g. 'baselines.her.ActorCritic')
        :param polyak: (float) coefficient for Polyak-averaging of the target network
        :param batch_size: (int) batch size for training
        :param q_lr: (float) learning rate for the Q (critic) network
        :param pi_lr: (float) learning rate for the pi (actor) network
        :param norm_eps: (float) a small value used in the normalizer to avoid numerical instabilities
        :param norm_clip: (float) normalized inputs are clipped to be in [-norm_clip, norm_clip]
        :param max_u: (float) maximum action magnitude, i.e. actions are in [-max_u, max_u]
        :param action_l2: (float) coefficient for L2 penalty on the actions
        :param clip_obs: (float) clip observations before normalization to be in [-clip_obs, clip_obs]
        :param scope: (str) the scope used for the TensorFlow graph
        :param time_horizon: (int) the time horizon for rollouts
        :param rollout_batch_size: (int) number of parallel rollouts per DDPG agent
        :param subtract_goals: (function (numpy Number, numpy Number): numpy Number) function that subtracts goals
            from each other
        :param relative_goals: (boolean) whether or not relative goals should be fed into the network
        :param clip_pos_returns: (boolean) whether or not positive returns should be clipped
        :param clip_return: (float) clip returns to be in [-clip_return, clip_return]
        :param sample_transitions: (function (dict, int): dict) function that samples from the replay buffer
        :param gamma: (float) gamma used for Q learning updates
        :param reuse: (boolean) whether or not the networks should be reused
        """
        # Updated in experiments/config.py
        self.input_dims = input_dims
        self.buffer_size = buffer_size
        self.hidden = hidden
        self.layers = layers
        self.network_class = network_class
        self.polyak = polyak
        self.batch_size = batch_size
        self.q_lr = q_lr
        self.pi_lr = pi_lr
        self.norm_eps = norm_eps
        self.norm_clip = norm_clip
        self.max_u = max_u
        self.action_l2 = action_l2
        self.clip_obs = clip_obs
        self.scope = scope
        self.time_horizon = time_horizon
        self.rollout_batch_size = rollout_batch_size
        self.subtract_goals = subtract_goals
        self.relative_goals = relative_goals
        self.clip_pos_returns = clip_pos_returns
        self.clip_return = clip_return
        self.sample_transitions = sample_transitions
        self.gamma = gamma
        self.reuse = reuse

        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dim_obs = self.input_dims['o']
        self.dim_goal = self.input_dims['g']
        self.dim_action = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.time_horizon if key != 'o' else self.time_horizon + 1,
                  *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dim_goal)
        buffer_shapes['ag'] = (self.time_horizon + 1, self.dim_goal)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size,
                                   self.time_horizon, self.sample_transitions)

    def _random_action(self, num):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(num, self.dim_action))

    def _preprocess_obs_goal(self, obs, achieved_goal, goal):
        if self.relative_goals:
            g_shape = goal.shape
            goal = goal.reshape(-1, self.dim_goal)
            achieved_goal = achieved_goal.reshape(-1, self.dim_goal)
            goal = self.subtract_goals(goal, achieved_goal)
            goal = goal.reshape(*g_shape)
        obs = np.clip(obs, -self.clip_obs, self.clip_obs)
        goal = np.clip(goal, -self.clip_obs, self.clip_obs)
        return obs, goal

    def get_actions(self,
                    obs,
                    achieved_goal,
                    goal,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_q=False):
        """
        return the action from an observation and goal

        :param obs: (numpy Number) the observation
        :param achieved_goal: (numpy Number) the achieved goal
        :param goal: (numpy Number) the goal
        :param noise_eps: (float) the noise epsilon
        :param random_eps: (float) the random epsilon
        :param use_target_net: (bool) whether or not to use the target network
        :param compute_q: (bool) whether or not to compute Q value
        :return: (numpy float or float) the actions
        """
        obs, goal = self._preprocess_obs_goal(obs, achieved_goal, goal)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_q:
            vals += [policy.q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            obs.reshape(-1, self.dim_obs),
            policy.g_tf:
            goal.reshape(-1, self.dim_goal),
            policy.u_tf:
            np.zeros((obs.size // self.dim_obs, self.dim_action),
                     dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        action = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *action.shape)  # gaussian noise
        action += noise
        action = np.clip(action, -self.max_u, self.max_u)
        # eps-greedy
        n_ac = action.shape[0]
        action += np.random.binomial(1, random_eps, n_ac).reshape(
            -1, 1) * (self._random_action(n_ac) - action)
        if action.shape[0] == 1:
            action = action[0]
        action = action.copy()
        ret[0] = action

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        Store the episode transitions

        :param episode_batch: (numpy Number) array of batch_size x (T or T+1) x dim_key 'o' is of size T+1,
            others are of size T
        :param update_stats: (bool) whether to update stats or not
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            obs, _, goal, achieved_goal = transitions['o'], transitions[
                'o_2'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_obs_goal(
                obs, achieved_goal, goal)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        """
        returns the current buffer size

        :return: (int) buffer size
        """
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, q_grad, pi_grad = self.sess.run([
            self.q_loss_tf, self.main.q_pi_tf, self.q_grad_tf, self.pi_grad_tf
        ])
        return critic_loss, actor_loss, q_grad, pi_grad

    def _update(self, q_grad, pi_grad):
        self.q_adam.update(q_grad, self.q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        """
        sample a batch

        :return: (dict) the batch
        """
        transitions = self.buffer.sample(self.batch_size)
        obs, obs_2, goal = transitions['o'], transitions['o_2'], transitions[
            'g']
        achieved_goal, achieved_goal_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_obs_goal(
            obs, achieved_goal, goal)
        transitions['o_2'], transitions['g_2'] = self._preprocess_obs_goal(
            obs_2, achieved_goal_2, goal)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        """
        apply a batch to staging

        :param batch: (dict) the batch to add to staging, if None: self.sample_batch()
        """
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        """
        train DDPG

        :param stage: (bool) enable staging
        :return: (float, float) critic loss, actor loss
        """
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, q_grad, pi_grad = self._grads()
        self._update(q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        """
        update the target network
        """
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        """
        clears the replay buffer
        """
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dim_action, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as scope:
            if reuse:
                scope.reuse_variables()
            self.o_stats = Normalizer(self.dim_obs,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as scope:
            if reuse:
                scope.reuse_variables()
            self.g_stats = Normalizer(self.dim_goal,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as scope:
            if reuse:
                scope.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            scope.reuse_variables()
        with tf.variable_scope('target') as scope:
            if reuse:
                scope.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            scope.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
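        # 1-step Bellman target: r + gamma * Q(o_2, pi(o_2)), clipped to clip_range below
        # (the upper bound becomes 0 when clip_pos_returns is set).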
        target_q_pi_tf = self.target.q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_q_pi_tf, *clip_range)

        self.q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.q_tf))
        self.pi_loss_tf = -tf.reduce_mean(self.main.q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))

        q_grads_tf = tf.gradients(self.q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))

        assert len(self._vars('main/Q')) == len(q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)

        self.q_grads_vars_tf = zip(q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.q_grad_tf = flatten_grads(grads=q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
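        # init_target_net_op copies the main weights into the target network once;
        # update_target_net_op then tracks them with an exponential moving average:
        # target <- polyak * target + (1 - polyak) * main.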
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        """
        create a log dictionary
        :param prefix: (str) the prefix for every index
        :return: ({str: Any}) the log
        """
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([subname not in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for key, value in state.items():
            if key[-6:] == '_stats':
                self.__dict__[key] = value
        # load TF variables
        _vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert len(_vars) == len(state["tf"])
        node = [tf.assign(var, val) for var, val in zip(_vars, state["tf"])]
        self.sess.run(node)
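The action post-processing in get_actions above is easy to miss among the session plumbing, so here is a minimal standalone NumPy sketch of just that step; the function name postprocess_actions and the values of max_u, noise_eps and random_eps are illustrative assumptions, not taken from the example.

# Standalone sketch of the Gaussian-noise + eps-greedy post-processing applied to the actor output.
import numpy as np

def postprocess_actions(actions, max_u=1.0, noise_eps=0.2, random_eps=0.3):
    # Gaussian exploration noise, scaled by the action bound max_u.
    actions = actions + noise_eps * max_u * np.random.randn(*actions.shape)
    actions = np.clip(actions, -max_u, max_u)
    # eps-greedy: with probability random_eps, jump to a uniformly random action instead.
    n, dim_u = actions.shape
    random_actions = np.random.uniform(-max_u, max_u, size=(n, dim_u))
    switch = np.random.binomial(1, random_eps, n).reshape(-1, 1)
    return actions + switch * (random_actions - actions)

print(postprocess_actions(np.zeros((4, 3))))  # 4 actions of dimension 3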
Example #22
0
class DDPG(object):
    @store_args
    def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
                 Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
                 rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
                 bc_loss, q_filter, num_demo, demo_batch_size, prm_loss_weight, aux_loss_weight,
                 sample_transitions, gamma, reuse=False, **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
            Added functionality to use demonstrations for training to overcome the exploration problem.

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
            bc_loss: whether or not the behavior cloning loss should be used as an auxiliary loss
            q_filter: whether or not a filter on the q value update should be used when training with demonstrations
            num_demo: Number of episodes to be used in the demonstration buffer
            demo_batch_size: number of samples to be used from the demonstration buffer, per MPI thread
            prm_loss_weight: Weight corresponding to the primary loss
            aux_loss_weight: Weight corresponding to the auxiliary loss, also called the cloning loss
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None,)
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {key: (self.T-1 if key != 'o' else self.T, *input_shapes[key])
                         for key, val in input_shapes.items()}
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T, self.dimg)

        buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)

        global DEMO_BUFFER
        DEMO_BUFFER = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer; in the same way as the primary data buffer

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def step(self, obs):
        actions = self.get_actions(obs['observation'], obs['achieved_goal'], obs['desired_goal'])
        return actions, None, None, None


    def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf: o.reshape(-1, self.dimo),
            policy.g_tf: g.reshape(-1, self.dimg),
            policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(*u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def init_demo_buffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer

        demoData = np.load(demoDataFile) #load the demonstration data from data file
        info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
        info_values = [np.empty((self.T - 1, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]

        demo_data_obs = demoData['obs']
        demo_data_acs = demoData['acs']
        demo_data_info = demoData['info']

        for epsd in range(self.num_demo): # we initialize the whole demo buffer at the start of the training
            obs, acts, goals, achieved_goals = [], [], [], []
            i = 0
            for transition in range(self.T - 1):
                obs.append([demo_data_obs[epsd][transition].get('observation')])
                acts.append([demo_data_acs[epsd][transition]])
                goals.append([demo_data_obs[epsd][transition].get('desired_goal')])
                achieved_goals.append([demo_data_obs[epsd][transition].get('achieved_goal')])
                for idx, key in enumerate(info_keys):
                    info_values[idx][transition, i] = demo_data_info[epsd][transition][key]


            obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
            achieved_goals.append([demo_data_obs[epsd][self.T - 1].get('achieved_goal')])

            episode = dict(o=obs,
                           u=acts,
                           g=goals,
                           ag=achieved_goals)
            for key, value in zip(info_keys, info_values):
                episode['info_{}'.format(key)] = value

            episode = convert_episode_to_batch_major(episode)
            global DEMO_BUFFER
            DEMO_BUFFER.store_episode(episode) # create the observation dict and append them into the demonstration buffer
            logger.debug("Demo buffer size currently ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size

            if update_stats:
                # add transitions to normalizer to normalize the demo data as well
                episode['o_2'] = episode['o'][:, 1:, :]
                episode['ag_2'] = episode['ag'][:, 1:, :]
                num_normalizing_transitions = transitions_in_episode_batch(episode)
                transitions = self.sample_transitions(episode, num_normalizing_transitions)

                o, g, ag = transitions['o'], transitions['g'], transitions['ag']
                transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
                # No need to preprocess the o_2 and g_2 since this is only used for stats

                self.o_stats.update(transitions['o'])
                self.g_stats.update(transitions['g'])

                self.o_stats.recompute_stats()
                self.g_stats.recompute_stats()
            episode.clear()

        logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
            transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        if self.bc_loss: #use demonstration buffer to sample as well if bc_loss flag is set TRUE
            transitions = self.buffer.sample(self.batch_size - self.demo_batch_size)
            global DEMO_BUFFER
            transitions_demo = DEMO_BUFFER.sample(self.demo_batch_size) #sample from the demo buffer
            for k, values in transitions_demo.items():
                rolloutV = transitions[k].tolist()
                for v in values:
                    rolloutV.append(v.tolist())
                transitions[k] = np.array(rolloutV)
        else:
            transitions = self.buffer.sample(self.batch_size) #otherwise only sample from primary buffer

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)

        transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))
        self.sess = tf_util.get_session()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([(key, batch[i])
                                for i, key in enumerate(self.stage_shapes.keys())])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        #choose only the demo buffer samples
        mask = np.concatenate((np.zeros(self.batch_size - self.demo_batch_size), np.ones(self.demo_batch_size)), axis=0)

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(
                target_batch_tf, net_type='target', **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))

        if self.bc_loss == 1 and self.q_filter == 1:  # train with demonstrations and use bc_loss and q_filter both
            maskMain = tf.reshape(tf.boolean_mask(self.main.Q_tf > self.main.Q_pi_tf, mask), [-1]) #where is the demonstrator action better than actor action according to the critic? choose those samples only
            #define the cloning loss on the actor's actions only on the samples which adhere to the above masks
            self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask(tf.boolean_mask((self.main.pi_tf), mask), maskMain, axis=0) - tf.boolean_mask(tf.boolean_mask((batch_tf['u']), mask), maskMain, axis=0)))
            self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf) #primary loss scaled by its respective weight prm_loss_weight
            self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u)) #L2 loss on action values scaled by the same weight prm_loss_weight
            self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf #adding the cloning loss to the actor loss as an auxiliary loss scaled by its weight aux_loss_weight

        elif self.bc_loss == 1 and self.q_filter == 0: # train with demonstrations without q_filter
            self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask((self.main.pi_tf), mask) - tf.boolean_mask((batch_tf['u']), mask)))
            self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf)
            self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
            self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf

        else:  # if not training with demonstrations
            self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
            self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))

        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
                             'main', 'target', 'lock', 'env', 'sample_transitions',
                             'stage_shapes', 'create_actor_critic']

        state = {k: v for k, v in self.__dict__.items() if all([subname not in k for subname in excluded_subnames])}
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        _vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert len(_vars) == len(state["tf"])
        node = [tf.assign(var, val) for var, val in zip(_vars, state["tf"])]
        self.sess.run(node)

    def save(self, save_path):
        tf_util.save_variables(save_path)
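As a reading aid for the bc_loss/q_filter branch in _create_network above, the following is a minimal NumPy sketch of the same masking logic outside TensorFlow; the batch layout (rollout samples first, demo samples appended at the end) mirrors the mask construction in the example, while all shapes and values are assumptions chosen only for illustration.

# Standalone sketch of the demo mask and Q-filter (assumed shapes and random placeholder values).
import numpy as np

batch_size, demo_batch_size, dim_u = 8, 3, 2
# Demo transitions are appended after the rollout transitions, so the mask selects the tail of the batch.
demo_mask = np.concatenate([np.zeros(batch_size - demo_batch_size, dtype=bool),
                            np.ones(demo_batch_size, dtype=bool)])

pi_actions = np.random.uniform(-1, 1, (batch_size, dim_u))     # actor output (main.pi_tf)
batch_actions = np.random.uniform(-1, 1, (batch_size, dim_u))  # stored actions (batch_tf['u'])
q_of_batch_action = np.random.randn(batch_size)                # critic value of the stored action (main.Q_tf)
q_of_pi_action = np.random.randn(batch_size)                   # critic value of the actor action (main.Q_pi_tf)

# Q-filter: clone only those demo actions that the critic rates higher than the actor's own action.
keep = q_of_batch_action[demo_mask] > q_of_pi_action[demo_mask]
cloning_loss = np.sum((pi_actions[demo_mask][keep] - batch_actions[demo_mask][keep]) ** 2)
print(cloning_loss)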
Example #23
0
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 bc_loss,
                 q_filter,
                 num_demo,
                 demo_batch_size,
                 prm_loss_weight,
                 aux_loss_weight,
                 sample_transitions,
                 gamma,
                 reuse=False,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
            Added functionality to use demonstrations for training to overcome the exploration problem.

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
            bc_loss: whether or not the behavior cloning loss should be used as an auxiliary loss
            q_filter: whether or not a filter on the q value update should be used when training with demonstrations
            num_demo: Number of episodes to be used in the demonstration buffer
            demo_batch_size: number of samples to be used from the demonstration buffer, per MPI thread
            prm_loss_weight: Weight corresponding to the primary loss
            aux_loss_weight: Weight corresponding to the auxiliary loss, also called the cloning loss
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T - 1 if key != 'o' else self.T, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T, self.dimg)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

        global DEMO_BUFFER
        DEMO_BUFFER = ReplayBuffer(
            buffer_shapes, buffer_size, self.T, self.sample_transitions
        )  #initialize the demo buffer; in the same way as the primary data buffer

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def step(self, obs):
        actions = self.get_actions(obs['observation'], obs['achieved_goal'],
                                   obs['desired_goal'])
        return actions, None, None, None

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def init_demo_buffer(
            self,
            demoDataFile,
            update_stats=True):  #function that initializes the demo buffer

        demoData = np.load(
            demoDataFile)  #load the demonstration data from data file
        info_keys = [
            key.replace('info_', '') for key in self.input_dims.keys()
            if key.startswith('info_')
        ]
        info_values = [
            np.empty((self.T - 1, 1, self.input_dims['info_' + key]),
                     np.float32) for key in info_keys
        ]

        demo_data_obs = demoData['obs']
        demo_data_acs = demoData['acs']
        demo_data_info = demoData['info']

        for epsd in range(
                self.num_demo
        ):  # we initialize the whole demo buffer at the start of the training
            obs, acts, goals, achieved_goals = [], [], [], []
            i = 0
            for transition in range(self.T - 1):
                obs.append(
                    [demo_data_obs[epsd][transition].get('observation')])
                acts.append([demo_data_acs[epsd][transition]])
                goals.append(
                    [demo_data_obs[epsd][transition].get('desired_goal')])
                achieved_goals.append(
                    [demo_data_obs[epsd][transition].get('achieved_goal')])
                for idx, key in enumerate(info_keys):
                    info_values[idx][transition,
                                     i] = demo_data_info[epsd][transition][key]

            obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
            achieved_goals.append(
                [demo_data_obs[epsd][self.T - 1].get('achieved_goal')])

            episode = dict(o=obs, u=acts, g=goals, ag=achieved_goals)
            for key, value in zip(info_keys, info_values):
                episode['info_{}'.format(key)] = value

            episode = convert_episode_to_batch_major(episode)
            global DEMO_BUFFER
            DEMO_BUFFER.store_episode(
                episode
            )  # create the observation dict and append them into the demonstration buffer
            logger.debug("Demo buffer size currently ",
                         DEMO_BUFFER.get_current_size()
                         )  #print out the demonstration buffer size

            if update_stats:
                # add transitions to normalizer to normalize the demo data as well
                episode['o_2'] = episode['o'][:, 1:, :]
                episode['ag_2'] = episode['ag'][:, 1:, :]
                num_normalizing_transitions = transitions_in_episode_batch(
                    episode)
                transitions = self.sample_transitions(
                    episode, num_normalizing_transitions)

                o, g, ag = transitions['o'], transitions['g'], transitions[
                    'ag']
                transitions['o'], transitions['g'] = self._preprocess_og(
                    o, ag, g)
                # No need to preprocess the o_2 and g_2 since this is only used for stats

                self.o_stats.update(transitions['o'])
                self.g_stats.update(transitions['g'])

                self.o_stats.recompute_stats()
                self.g_stats.recompute_stats()
            episode.clear()

        logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size()
                    )  #print out the demonstration buffer size

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf, self.main.Q_pi_tf, self.Q_grad_tf, self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        if self.bc_loss:  #use demonstration buffer to sample as well if bc_loss flag is set TRUE
            transitions = self.buffer.sample(self.batch_size -
                                             self.demo_batch_size)
            global DEMO_BUFFER
            transitions_demo = DEMO_BUFFER.sample(
                self.demo_batch_size)  #sample from the demo buffer
            for k, values in transitions_demo.items():
                rolloutV = transitions[k].tolist()
                for v in values:
                    rolloutV.append(v.tolist())
                transitions[k] = np.array(rolloutV)
        else:
            transitions = self.buffer.sample(
                self.batch_size)  #otherwise only sample from primary buffer

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)
        assert np.array_equal(transitions['g_2'], transitions['g'])

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))
        self.sess = tf_util.get_session()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        #choose only the demo buffer samples
        mask = np.concatenate(
            (np.zeros(self.batch_size - self.demo_batch_size),
             np.ones(self.demo_batch_size)),
            axis=0)

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))

        if self.bc_loss == 1 and self.q_filter == 1:  # train with demonstrations and use bc_loss and q_filter both
            maskMain = tf.reshape(
                tf.boolean_mask(self.main.Q_tf > self.main.Q_pi_tf,
                                mask), [-1]
            )  #where is the demonstrator action better than actor action according to the critic? choose those samples only
            #define the cloning loss on the actor's actions only on the samples which adhere to the above masks
            self.cloning_loss_tf = tf.reduce_sum(
                tf.square(
                    tf.boolean_mask(tf.boolean_mask((self.main.pi_tf), mask),
                                    maskMain,
                                    axis=0) -
                    tf.boolean_mask(tf.boolean_mask((batch_tf['u']), mask),
                                    maskMain,
                                    axis=0)))
            self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(
                self.main.Q_pi_tf
            )  #primary loss scaled by its respective weight prm_loss_weight
            self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(
                tf.square(self.main.pi_tf / self.max_u)
            )  #L2 loss on action values scaled by the same weight prm_loss_weight
            self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf  #adding the cloning loss to the actor loss as an auxiliary loss scaled by its weight aux_loss_weight

        elif self.bc_loss == 1 and self.q_filter == 0:  # train with demonstrations without q_filter
            self.cloning_loss_tf = tf.reduce_sum(
                tf.square(
                    tf.boolean_mask((self.main.pi_tf), mask) -
                    tf.boolean_mask((batch_tf['u']), mask)))
            self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(
                self.main.Q_pi_tf)
            self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(
                tf.square(self.main.pi_tf / self.max_u))
            self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf

        else:  # if not training with demonstrations
            self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
            self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
                tf.square(self.main.pi_tf / self.max_u))

        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([subname not in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        _vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert len(_vars) == len(state["tf"])
        node = [tf.assign(var, val) for var, val in zip(_vars, state["tf"])]
        self.sess.run(node)

    def save(self, save_path):
        tf_util.save_variables(save_path)
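The target-network handling above reduces to two operations: a one-time hard copy (init_target_net_op) and a Polyak-averaged update (update_target_net_op). Below is a minimal standalone sketch of that update rule, with random NumPy arrays standing in for the TensorFlow variables; the polyak value and shapes are assumptions for illustration.

# Standalone sketch of Polyak averaging of target-network weights (placeholder values).
import numpy as np

polyak = 0.95
main_vars = [np.random.randn(4, 4), np.random.randn(4)]
target_vars = [v.copy() for v in main_vars]  # init_target_net_op: target <- main (hard copy)

def polyak_update(target_vars, main_vars, polyak):
    # update_target_net_op: target <- polyak * target + (1 - polyak) * main, per variable.
    return [polyak * t + (1.0 - polyak) * m for t, m in zip(target_vars, main_vars)]

target_vars = polyak_update(target_vars, main_vars, polyak)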
Example #24
0
File: ddpg.py  Project: Baichenjia/BHER
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 r_bias,
                 bias_clip_low,
                 bias_clip_high,
                 n_epochs,
                 ismuti,
                 reuse=False,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T if key != 'o' else self.T + 1, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

        self.total_epoch_r_mean_bias = []
        self.total_epoch_r_std_bias = []
        self.rb = r_bias
        self.bias_clip_low = bias_clip_low
        self.bias_clip_high = bias_clip_high
        self.isMuti = ismuti
        self.epcoch_num = 0
        self.isPlot = False
        self.picdir = ''
        self.rewdir = ''

    def save_reward_pic(self, reward):

        with open(self.rewdir, "wb") as fp:
            pickle.dump(reward, fp)

        plt.clf()
        print('min:', np.min(reward))
        print('max:', np.max(reward))
        # fig, ax = plt.subplots()
        plt.figure(figsize=(10, 8))
        high = min(np.max(reward), 5.)
        bins = np.arange(0., high, 0.1)
        plt.hist(reward,
                 bins,
                 alpha=0.5,
                 weights=[1. / len(reward)] * len(reward))  # alpha sets the transparency; 0 is fully transparent
        font2 = {
            'size': 18,
        }
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.xlabel('Bias', font2)
        plt.ylabel('Prob', font2)
        plt.grid(True)
        plt.xlim([0.0, high])

        plt.savefig(self.picdir)
        print('save pic path:', self.picdir)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False,
                    compute_r_bias=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
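        # callers that only need the bias computation take the raw policy output,
        # skipping the exploration noise and eps-greedy mixing below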
        if compute_r_bias:
            return ret[0]
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf, self.main.Q_pi_tf, self.Q_grad_tf, self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def recompute_reward(self, transitions):
        re_transitions = transitions['re_transitions']
        T = re_transitions['u'].shape[1]
        batch_size = re_transitions['u'].shape[0]
        # o(256, 51, 25) u (256, 50, 4) g (256, 50, 3) future_g(256,50,3) info_is_success (256, 50, 1)  o_2 (256, 50, 25) ag_2 (256, 50, 3)
        re_transitions['o'] = re_transitions['o'][:, :T, :].reshape(
            batch_size * T, self.dimo)
        re_transitions['ag'] = re_transitions['ag'][:, :T, :].reshape(
            batch_size * T, self.dimg)
        re_transitions['future_g'] = re_transitions['future_g'].reshape(
            batch_size * T, self.dimg)
        re_transitions['g'] = re_transitions['g'].reshape(
            batch_size * T, self.dimg)
        re_transitions['u'] = re_transitions['u'].reshape(
            batch_size * T, self.dimu)

        u1 = self.get_actions(re_transitions['o'],
                              re_transitions['ag'],
                              re_transitions['future_g'],
                              compute_r_bias=True)
        u2 = self.get_actions(re_transitions['o'],
                              re_transitions['ag'],
                              re_transitions['g'],
                              compute_r_bias=True)
        r_b = self.rb * (np.square(LA.norm(u2 - re_transitions['u'], axis=1)) -
                         np.square(LA.norm(u1 - re_transitions['u'], axis=1)))

        r_b = np.sum(r_b.reshape(batch_size, T), axis=1)

        e_r_b = np.exp(r_b)
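        # e_r_b is the per-episode weight derived from the bias; below it either rescales all
        # rewards directly (isMuti) or normalizes the non-HER rewards by its mean over HER samples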
        her_indexes = re_transitions['her_index']
        other_indexes = re_transitions['other_index']

        rank = MPI.COMM_WORLD.Get_rank()
        if rank == 0:
            self.total_epoch_r_mean_bias.append(e_r_b.mean())
            self.total_epoch_r_std_bias.append(e_r_b.std())

        if self.isPlot:
            self.save_reward_pic(e_r_b)
            self.isPlot = False

        if self.isMuti:
            transitions['r'] *= np.clip(e_r_b, self.bias_clip_low,
                                        self.bias_clip_high)
        else:  # batch projection
            e_r_b = np.clip(e_r_b, self.bias_clip_low, self.bias_clip_high)
            e_r_b_mean = np.mean(e_r_b[her_indexes])
            transitions['r'][other_indexes] /= e_r_b_mean

        del transitions['re_transitions']
        del transitions['origin_g']
        return transitions

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        # lky  recompute reward
        transitions = self.recompute_reward(transitions)

        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all(subname not in k for subname in excluded_subnames)
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
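
The weight computed in recompute_reward above can be reproduced in isolation: the per-step bias is r_bias times the difference of squared distances between the stored action and the policy's action under the original versus the relabeled goal, summed over each episode, exponentiated, and clipped. A self-contained numpy sketch with invented shapes and coefficients (none of these values come from the example):

import numpy as np
from numpy import linalg as LA

batch_size, T, dimu = 4, 5, 2
r_bias, clip_low, clip_high = 0.1, 0.1, 10.0

u = np.random.uniform(-1., 1., (batch_size * T, dimu))   # actions stored in the buffer
u1 = np.random.uniform(-1., 1., (batch_size * T, dimu))  # pi(o, future_g): action under the relabeled goal
u2 = np.random.uniform(-1., 1., (batch_size * T, dimu))  # pi(o, g): action under the original goal

# per-step bias, summed per episode, then exponentiated and clipped
r_b = r_bias * (np.square(LA.norm(u2 - u, axis=1)) - np.square(LA.norm(u1 - u, axis=1)))
e_r_b = np.clip(np.exp(r_b.reshape(batch_size, T).sum(axis=1)), clip_low, clip_high)
print(e_r_b.shape)  # (batch_size,)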
Example #25
0
class DDPG(object):
    @store_args
    def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
                 Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
                 rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
                 sample_transitions, gamma, reuse=False, **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None,)
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
                         for key, val in input_shapes.items()}
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T+1, self.dimg)

        buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g) # clip observations and goals
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf: o.reshape(-1, self.dimo),
            policy.g_tf: g.reshape(-1, self.dimg),
            policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        # ret = action given by the current policy (eval of NN)
        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(*u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        # Below: for each mini-batch we take action u (the one given by the policy) with probability
        # 1-random_eps, and a random action (u + random_action - u) with probability random_eps
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
            transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)

        transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([(key, batch[i])
                                for i, key in enumerate(self.stage_shapes.keys())])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(
                target_batch_tf, net_type='target', **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        # self.XX.pi_tf is the action policy we'll use for exploration
        # self.XX.Q_pi_tf is the critic evaluated at the policy's own action, used to train the actor
        # self.XX.Q_tf is the critic evaluated at the action taken from the staged batch

        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
        # target y_i = r + gamma * Q, the Bellman backup (with returns clipped if necessary):
        target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        # loss function for Q_tf where we exclude target_tf from the gradient computation:
        self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))

        # loss function for the action policy is that of the main Q_pi network:
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        # add L2 regularization term from the policy itself:
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))

        # define the gradients of the Q_loss and pi_loss wrt to their variables respectively
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)

        # zip the gradients together with their respective variables
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))

        # flattened gradients and variables
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

        # optimizers (using MPI for parallel updates of the network)
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)

        # polyak averaging used for the update of the target networks in both pi and Q nets
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
        # operation to initialize the target nets with the main nets' values
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        # operation to update the target nets from the main nets using polyak averaging
        self.update_target_net_op = list(
            map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()  # broadcasts parameters from rank 0 so all MPI workers start in sync
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
                             'main', 'target', 'lock', 'env', 'sample_transitions',
                             'stage_shapes', 'create_actor_critic']

        state = {k: v for k, v in self.__dict__.items() if all(subname not in k for subname in excluded_subnames)}
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert(len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
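
All of these implementations feed mini-batches through the same put/get pattern: a list of placeholders is staged with stage_op, and the training graph reads its inputs from staging_tf.get(). A stripped-down TF 1.x sketch of just that pattern (the import path and the (None, 3) shape are illustrative choices, not taken from any example above):

import numpy as np
import tensorflow as tf
from tensorflow.contrib.staging import StagingArea  # one possible TF 1.x import path

ph = tf.placeholder(tf.float32, shape=(None, 3))
stage = StagingArea(dtypes=[tf.float32], shapes=[(None, 3)])
stage_op = stage.put([ph])  # enqueue the next mini-batch
batch = stage.get()         # training ops consume the staged batch (a single tensor here)
loss = tf.reduce_mean(tf.square(batch))

with tf.Session() as sess:
    sess.run(stage_op, feed_dict={ph: np.ones((2, 3), np.float32)})
    print(sess.run(loss))   # reads the batch staged above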
Example #26
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 sample_transitions,
                 gamma,
                 gg_k,
                 replay_strategy,
                 reuse=False,
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)
        self.replay_strategy = replay_strategy

        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
            self.max_g = kwargs['max_g']
            self.d0 = kwargs['d0']
            self.slope = kwargs['slope']
            self.goal_lr = kwargs['goal_lr']
            # reward shaping parameters
            self.rshape_lambda = kwargs['rshape_lambda']
            self.reshape_p = kwargs['rshape_p']
            self.rshaping = kwargs['rshaping']

            self.input_dims['e'] = self.dimg * self.T
            self.input_dims['mask'] = self.T
            self.dime = self.input_dims['e']
            self.dim_mask = self.input_dims['mask']

        input_shapes = dims_to_shapes(self.input_dims)

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape)
                for shape in self.stage_shapes.values()
            ]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T if key != 'o' else self.T + 1, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T + 1, self.dimg)

        if self.replay_strategy in [
                C.REPLAY_STRATEGY_BEST_K, C.REPLAY_STRATEGY_GEN_K,
                C.REPLAY_STRATEGY_GEN_K_GMM
        ]:
            buffer_shapes['gg'] = (self.T, self.gg_k, self.dimg)

        if self.replay_strategy in [
                C.REPLAY_STRATEGY_BEST_K, C.REPLAY_STRATEGY_GEN_K_GMM
        ]:
            buffer_shapes['gg_idx'] = (self.T, self.gg_k)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def _preprocess_e(self, e):
        e = np.clip(e, -self.clip_obs, self.clip_obs)
        return e

    # def td_error(self, o, g):
    #     vals = [self.Q_loss_tf]
    def get_target_q_val(self, o, ag, g):
        vals = [self.target.Q_pi_tf]
        feed = {
            self.target.o_tf: o.reshape(-1, self.dimo),
            self.target.g_tf: g.reshape(-1, self.dimg)
        }
        ret = self.sess.run(vals, feed_dict=feed)
        return ret[0]

    def get_goals(self, u_goal, e, mask, use_target_net=False):
        """
        :param u_goal: batch_size * dim_u dimensional array
        :param e: batch_size * (T*dim_g) dimensional array
        :param mask: batch_size * T dimensional array
        :param use_target_net: True/False
        :return:
        """
        e = self._preprocess_e(e)
        policy = self.target if use_target_net else self.main
        vals = [
            policy.goal_tf, policy.distance, policy.e_reshaped,
            policy.goal_tf_repeated, policy.reward_sum
        ]
        # feed
        feed = {
            policy.e_tf: e.reshape(-1, self.dime),
            policy.mask_tf: mask.reshape(-1, self.dim_mask),
            policy.u_tf: u_goal.reshape(-1, self.dimu)
        }
        ret = self.sess.run(vals, feed_dict=feed)
        # print("Generated goal: ")
        # print("Goal: ", ret[0])
        # print("Distance: ", ret[1])
        # print("Episode: ", ret[2])
        # print("Goal repeated: ", ret[3])
        # print("Reward: ", np.average(ret[4]))
        # print('---------------------------------------------------------------')
        # for var in self._vars('main/goal'):
        #     print("Name: " + var.name)
        #     print("Shape: " + str(var.shape))
        #     print(var.eval())
        return ret[0]

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
                'g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

            if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
                e = transitions['e']
                transitions['e'] = self._preprocess_e(e)
                self.e_stats.update(transitions['e'])
                self.e_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()
        if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
            self.goal_adam.sync()
            self.Q_goal_adam.sync()
            self.pi_goal_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        tf_list = [
            self.Q_loss_tf, self.main.Q_pi_tf, self.Q_grad_tf, self.pi_grad_tf
        ]
        if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
            tf_list.extend([self.goal_loss_tf, self.goal_grad_tf])
            tf_list.extend([self.Q_goal_loss_tf, self.Q_goal_grad_tf])
            tf_list.extend([self.pi_goal_loss_tf, self.pi_goal_grad_tf])
            tf_list.extend([self.main.mask_tf, self.main.d])
        return self.sess.run(tf_list)

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def _update_goal(self, goal_grad, Q_goal_grad, pi_goal_grad):
        # self.Q_goal_adam.update(Q_goal_grad, self.Q_lr)
        # self.pi_goal_adam.update(pi_goal_grad, self.pi_lr)
        self.goal_adam.update(goal_grad, self.goal_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        return self.batch_from_transitions(transitions)

    def batch_from_transitions(self, transitions):
        """
            transitions is a dictionary with keys: ['o', 'ag', 'u', 'o_2', 'ag_2', 'r', 'g']
            batch is a processed batch (normalizing, clipping, relative goal) for staging,
                and has the keys ['o', 'ag', 'u', 'o_2', 'ag_2', 'r', g', 'g_2']
        """
        # preprocess observations and goals
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
            e = transitions['e']
            transitions['e'] = self._preprocess_e(e)

        # Set the correct order of keys in the batch
        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        # print("*********************************Training*******************************")
        if stage:
            self.stage_batch()

        if self.replay_strategy != C.REPLAY_STRATEGY_GEN_K:
            critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        else:
            critic_loss, actor_loss, Q_grad, pi_grad,\
            goal_loss, goal_grad, Q_goal_loss, Q_goal_grad, \
            pi_goal_loss, pi_goal_grad, x, y = self._grads()
            self._update_goal(goal_grad, Q_goal_grad, pi_goal_grad)

        self._update(Q_grad, pi_grad)
        # print("Loss: ", goal_loss)
        # print("mask: ", np.sum(x, axis=1))
        # print("distance: ", y)
        # print("Reward: ", r)

        # if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
        #     goal_loss = self.sess.run(self.target_Q_goal_tf)
        #     # self.goal_adam.update(goal_grad, self.goal_lr)
        #     print("Goal loss: ", goal_loss)

        if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
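            # after each update, copy the regular actor/critic weights into the
            # goal-generation networks (see copy_normal_to_goal_op in _create_network)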
            self.sess.run(self.copy_normal_to_goal_op)

        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." %
                    (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
            # running averages
            with tf.variable_scope('e_stats') as vs:
                if reuse:
                    vs.reuse_variables()
                self.e_stats = Normalizer(self.dime,
                                          self.norm_eps,
                                          self.norm_clip,
                                          sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf_vec = tf.square(
            tf.stop_gradient(target_tf) - self.main.Q_tf)
        self.Q_loss_tf = tf.reduce_mean(self.Q_loss_tf_vec)
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')

        if self.replay_strategy == C.REPLAY_STRATEGY_GEN_K:
            # loss functions for goal generation network
            target_Q_pi_goal_tf = self.target.Q_pi_goal_tf
            target_goal_tf = tf.clip_by_value(
                self.main.reward + self.gamma * target_Q_pi_goal_tf,
                *clip_range)
            self.goal_loss_tf = -self.LAMBDA * tf.reduce_mean(
                tf.square(
                    tf.stop_gradient(target_goal_tf) - self.main.Q_goal_tf))
            # self.goal_loss_tf += 0.0 * tf.reduce_mean(tf.square(self.main.goal_tf / self.max_g))
            # self.goal_loss_tf = 0
            # self.reward_sum = tf.reduce_mean(self.main.reward_sum)
            self.goal_loss_tf += -tf.reduce_mean(self.main.reward_sum)

            # loss functions for Q_goal and pi_goal
            self.Q_goal_loss_tf = tf.reduce_mean(
                tf.square(
                    tf.stop_gradient(target_goal_tf) - self.main.Q_goal_tf))
            self.pi_goal_loss_tf = -tf.reduce_mean(self.main.Q_pi_goal_tf)
            self.pi_goal_loss_tf += self.action_l2 * tf.reduce_mean(
                tf.square(self.main.pi_goal_tf / self.max_u))

            # gradients
            goal_grads_tf = tf.gradients(self.goal_loss_tf,
                                         self._vars('main/goal'))
            self.goal_grad_tf = flatten_grads(grads=goal_grads_tf,
                                              var_list=self._vars('main/goal'))

            Q_goal_grads_tf = tf.gradients(self.Q_goal_loss_tf,
                                           self._vars('main/gQ'))
            self.Q_goal_grad_tf = flatten_grads(grads=Q_goal_grads_tf,
                                                var_list=self._vars('main/gQ'))

            pi_goal_grads_tf = tf.gradients(self.pi_goal_loss_tf,
                                            self._vars('main/gpi'))
            self.pi_goal_grad_tf = flatten_grads(
                grads=pi_goal_grads_tf, var_list=self._vars('main/gpi'))

            assert len(self._vars('main/goal')) == len(goal_grads_tf)
            assert len(self._vars('main/gQ')) == len(Q_goal_grads_tf)
            assert len(self._vars('main/gpi')) == len(pi_goal_grads_tf)

            # optimizers
            self.goal_adam = MpiAdam(self._vars('main/goal'),
                                     scale_grad_by_procs=False)
            self.Q_goal_adam = MpiAdam(self._vars('main/gQ'),
                                       scale_grad_by_procs=False)
            self.pi_goal_adam = MpiAdam(self._vars('main/gpi'),
                                        scale_grad_by_procs=False)

            self.main_vars += self._vars('main/goal') + self._vars(
                'main/gQ') + self._vars('main/gpi')
            self.target_vars += self._vars('target/goal') + self._vars(
                'target/gQ') + self._vars('target/gpi')
            self.stats_vars += self._global_vars('e_stats')

            self.normal_vars = self._vars('main/Q') + self._vars(
                'main/pi') + self._vars('target/Q') + self._vars('target/pi')
            self.goal_vars = self._vars('main/gQ') + self._vars(
                'main/gpi') + self._vars('target/gQ') + self._vars(
                    'target/gpi')

            self.copy_normal_to_goal_op = list(
                map(lambda v: v[0].assign(0 * v[0] + 1 * v[1]),
                    zip(self.goal_vars, self.normal_vars)))

        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)
Example #27
0
class DDPG(object):
    @store_args
    def __init__(self,
                 input_dims,
                 buffer_size,
                 hidden,
                 layers,
                 network_class,
                 polyak,
                 batch_size,
                 Q_lr,
                 pi_lr,
                 norm_eps,
                 norm_clip,
                 max_u,
                 action_l2,
                 clip_obs,
                 scope,
                 T,
                 rollout_batch_size,
                 subtract_goals,
                 relative_goals,
                 clip_pos_returns,
                 clip_return,
                 bc_loss,
                 q_filter,
                 num_demo,
                 demo_batch_size,
                 prm_loss_weight,
                 aux_loss_weight,
                 sample_transitions,
                 gamma,
                 reuse=False,
                 pre_train_model=False,
                 update_model=True,
                 feature_net_path='',
                 **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
            Added functionality to use demonstrations for training to overcome the exploration problem.
        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
            bc_loss: whether or not the behavior cloning loss should be used as an auxiliary loss
            q_filter: whether or not a filter on the Q value update should be used when training with demonstrations
            num_demo: number of episodes to be used in the demonstration buffer
            demo_batch_size: number of samples to be used from the demonstration buffer, per MPI thread
            prm_loss_weight: weight corresponding to the primary loss
            aux_loss_weight: weight corresponding to the auxiliary loss, also called the cloning loss
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        # ADDED
        self.process_type = kwargs['process_type']
        self.contact_dim = kwargs['contact_dim']
        self.use_contact = (self.contact_dim > 0)
        self.pre_train_model = pre_train_model
        self.feature_net_path = feature_net_path
        self.__dict__['use_contact'] = self.use_contact
        self.__dict__['pre_train'] = self.pre_train_model

        self.create_actor_critic = import_function(self.network_class)
        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o'] - self.contact_dim
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']
        self.feature_dim = kwargs['feature_dim']
        self.contact_point_dim = self.contact_dim // self.fixed_num_of_contact

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None, )
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            logger.info("Creating a DDPG agent with action space %d x %s..." %
                        (self.dimu, self.max_u))
            self.sess = tf_util.get_session()
            # order: ['g', 'o', 'u', 'o_2', 'g_2', 'r'])
            if self.pre_train_model == 'cpc':
                self.staging_tf = StagingArea(
                    dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                    shapes=list(self.stage_shapes.values()))
                self.buffer_ph_tf = [
                    tf.placeholder(tf.float32, shape=shape)
                    for shape in self.stage_shapes.values()
                ]
                self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

                self.cpc_shape = OrderedDict()
                self.cpc_shape['obs_neg'] = (None, self.fixed_num_of_contact,
                                             self.contact_point_dim)
                self.cpc_shape['obs_pos'] = (None, self.fixed_num_of_contact,
                                             self.contact_point_dim)
                self.cpc_staging_tf = StagingArea(
                    dtypes=[tf.float32 for _ in self.cpc_shape.keys()],
                    shapes=list(self.cpc_shape.values()))
                self.cpc_buffer_ph_tf = [
                    tf.placeholder(tf.float32, shape=shape)
                    for shape in self.cpc_shape.values()
                ]
                self.cpc_stage_op = self.cpc_staging_tf.put(
                    self.cpc_buffer_ph_tf)
            else:
                self.staging_tf = StagingArea(
                    dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                    shapes=list(self.stage_shapes.values()))
                self.buffer_ph_tf = [
                    tf.placeholder(tf.float32, shape=shape)
                    for shape in self.stage_shapes.values()
                ]
                self.stage_op = self.staging_tf.put(self.buffer_ph_tf)
            self.update_model = update_model

            if self.pre_train_model != 'none':
                self.__dict__['feature_net_path'] = self.feature_net_path
                self.__dict__['clip_obs'] = self.clip_obs

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {
            key: (self.T - 1 if key != 'o' else self.T, *input_shapes[key])
            for key, val in input_shapes.items()
        }
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T, self.dimg)

        buffer_size = (self.buffer_size //
                       self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T,
                                   self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u,
                                 high=self.max_u,
                                 size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        # self.clip_obs = 200
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        if len(o.shape) == 1:
            o[-self.dimo:] = np.clip(o[-self.dimo:], -self.clip_obs,
                                     self.clip_obs)
        elif len(o.shape) == 2:
            o[:, -self.dimo:] = np.clip(o[:, -self.dimo:], -self.clip_obs,
                                        self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def step(self, obs):
        actions = self.get_actions(obs['observation'], obs['achieved_goal'],
                                   obs['desired_goal'])
        return actions, None, None, None

    def get_actions(self,
                    o,
                    ag,
                    g,
                    noise_eps=0.,
                    random_eps=0.,
                    use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        """lines added here, remove later"""
        ori = o[:, -7:-4].copy()
        noise = np.random.normal(0, 7e-4, ori.shape)
        o[:, -7:-4] += noise

        feed = {
            policy.o_tf:
            o.reshape(-1, self.dimo + self.contact_dim),
            policy.g_tf:
            g.reshape(-1, self.dimg),
            policy.u_tf:
            np.zeros((o.size // (self.dimo + self.contact_dim), self.dimu),
                     dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)

        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(
            *u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (
            self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """
        """lines added, remove later"""
        ori = episode_batch['o'][:, :, -7:-4].copy()
        noise = np.random.normal(0, 7e-4, ori.shape)
        episode_batch['o'][:, :, -7:-4] += noise

        self.buffer.store_episode(episode_batch)
        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(
                episode_batch)
            # change goals here, recompute rewards
            transitions = self.sample_transitions(episode_batch,
                                                  num_normalizing_transitions)

            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess o_2 and g_2 since they are only used for stats
            """Normalization stuff here. """
            self.o_stats.update(transitions['o'][:, -self.o_stats.size:])
            self.g_stats.update(transitions['g'])
            if self.pre_train_model in ['cpc', 'curl']:
                feed_dict = {self.main.o_tf: transitions['o']}
                features = self.sess.run(self.main.features,
                                         feed_dict=feed_dict)
                features = np.clip(features, -self.clip_obs, self.clip_obs)
                self.feature_stats.update(features)
                self.feature_stats.recompute_stats()
            # elif self.process_type == 'max_pool':
            #     feed_dict = {self.main.o_tf:transitions['o']}
            #     features = self.sess.run(self.main.features, feed_dict=feed_dict)
            #     self.feature_stats.update(features)
            #     self.feature_stats.recompute_stats()

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()
            return transitions['o']
        # `transitions` only exists when update_stats is True; fall back to the raw stored observations.
        return episode_batch['o']

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()
        if self.pre_train_model == 'supervised':
            self.feature_adam.sync()
        elif self.pre_train_model == 'cpc':
            self.cpc.sync()
        elif self.pre_train_model == 'curl':
            self.curl_adam.sync()
            self.encoder_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf, self.main.Q_pi_tf, self.Q_grad_tf, self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(
            self.batch_size)  #otherwise only sample from primary buffer
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(
            o_2, ag_2, g)

        transitions_batch = [
            transitions[key] for key in self.stage_shapes.keys()
            if key not in ['obs_pos', 'obs_neg']
        ]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        """lines added, remove them later"""
        ori = batch[1][:, -7:-4].copy()
        noise = np.random.normal(0, 7e-4, ori.shape)
        batch[1][:, -7:-4] += noise
        noise = np.random.normal(0, 7e-4, ori.shape)
        batch[3][:, -7:-4] += noise

        self.sess.run(self.stage_op,
                      feed_dict=dict(zip(self.buffer_ph_tf, batch)))

        if self.pre_train_model == 'supervised':
            assert batch[1].shape[1] == 583, "must use full observations"
            # 253, 251, 246, 233, 232, 220, 215, 210
            # feature_loss, max_feature_loss, feature_grad = self.sess.run([self.feature_loss_tf, self.max_feature_loss, self.feature_grad_tf])
            feature_loss, feature_grad = self.sess.run(
                [self.feature_loss_tf, self.feature_grad_tf])
            self.feature_adam.update(feature_grad, 1e-3)
            self.sess.run(self.update_feature_weights_target)
            self.sess.run(self.stage_op,
                          feed_dict=dict(zip(self.buffer_ph_tf, batch)))
            # writer = tf.summary.FileWriter("home/vioichigo/try/tactile-baselines/graph", self.sess.graph)
            # print(self.sess.run(self.main.features))
            # writer.close()
            return feature_loss
        elif self.pre_train_model == 'cpc':
            # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # run_metadata = tf.RunMetadata()
            # obs = pickle.load(open('/home/vioichigo/try/tactile-baselines/dataset/HandManipulateEgg-v0/50000obs.pickle', 'rb'))
            # indices = np.random.randint(obs.shape[0], size=batch[1].shape[0] * (self.main.n_neg - 1))
            # obs_neg = obs[indices]
            # obs_pos = batch[3][:, :self.contact_dim].reshape((-1, self.fixed_num_of_contact, self.contact_dim//self.fixed_num_of_contact))
            # # self.sess.run(self.cpc_stage_op, feed_dict=dict(zip(self.cpc_buffer_ph_tf, [obs_neg, obs_pos])), options=run_options, run_metadata=run_metadata)
            # first = time.time()
            # # self.sess.run(self.cpc_stage_op, feed_dict=dict(zip(self.cpc_buffer_ph_tf, [obs_neg, obs_pos])))
            # start = time.time()
            # print("feed:", start - first)
            # feed_dict = {self.cpc_inputs_tf['obs_pos']: obs_pos, self.cpc_inputs_tf['obs_neg']: obs_neg}
            # # dict(zip(self.cpc_inputs_tf, [obs_neg, obs_pos]))
            # cpc_loss, cpc_grad = self.sess.run([self.cpc_loss_tf, self.cpc_grad_tf], feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
            # tl = timeline.Timeline(run_metadata.step_stats)
            # ctf = tl.generate_chrome_trace_format()
            # with open('./timeline.json', 'w') as f:
            #     f.write(ctf)
            # now = time.time()
            # print("compute_loss", now - start)
            # self.cpc_adam.update(cpc_grad, 1e-3)
            # print("update weights", time.time() - now)
            # self.sess.run(self.update_cpc_weights_target)
            # self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))

            return 1  # the CPC update above is commented out; return a dummy loss value
        elif self.pre_train_model == 'curl':
            curl_loss, curl_grad, encoder_grad = self.sess.run(
                [self.curl_loss, self.curl_grad_tf, self.encoder_grad_tf])
            self.curl_adam.update(curl_grad, 1e-3)
            self.encoder_adam.update(encoder_grad, 1e-3)

            self.sess.run(self.stage_op,
                          feed_dict=dict(zip(self.buffer_ph_tf, batch)))
            self.sess.run(self.update_curl_weights_op)

            return curl_loss

            # return cpc_loss

    def train(self, stage=True):
        feature_loss = None  # stays defined even if stage=False
        if stage:
            if self.pre_train_model == 'none':
                self.stage_batch()
            else:
                feature_loss = self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        if self.pre_train_model == 'none':
            return critic_loss, actor_loss
        else:
            return critic_loss, actor_loss, feature_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope=self.scope + '/' + scope)
        if self.pre_train_model == 'supervised':
            if not self.update_model:
                res = [x for x in res if x.name.find('predicted_pos') == -1]
        elif self.pre_train_model == 'cpc':
            if not self.update_model:
                res = [x for x in res if x.name.find('new_cpc') == -1]
        # elif self.pre_train_model == 'curl':
        #     res = [x for x in res if x.name.find('W') == -1]
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        # running averages
        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.increment_global_step = tf.assign_add(
            self.global_step, 1, name='increment_global_step')
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            """Normalization stuff here. """
            if self.use_contact and self.process_type in ['none', 'test']:
                self.o_stats = Normalizer(self.dimo + self.contact_dim,
                                          self.norm_eps,
                                          self.norm_clip,
                                          sess=self.sess)
            else:
                self.o_stats = Normalizer(self.dimo,
                                          self.norm_eps,
                                          self.norm_clip,
                                          sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg,
                                      self.norm_eps,
                                      self.norm_clip,
                                      sess=self.sess)

        if self.pre_train_model == 'cpc':
            with tf.variable_scope('feature_stats') as vs:
                if reuse:
                    vs.reuse_variables()
                z_dim = pickle.load(
                    open(self.feature_net_path + 'params.pickle', 'rb'))[0]
                self.feature_stats = Normalizer(z_dim,
                                                self.norm_eps,
                                                self.norm_clip,
                                                sess=self.sess)
                self.__dict__['feature_normalizer'] = self.feature_stats
        elif self.pre_train_model == 'curl':
            with tf.variable_scope('feature_stats') as vs:
                if reuse:
                    vs.reuse_variables()
                self.feature_stats = Normalizer(32,
                                                self.norm_eps,
                                                self.norm_clip,
                                                sess=self.sess)
                self.__dict__['feature_normalizer'] = self.feature_stats

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([
            (key, batch[i]) for i, key in enumerate(self.stage_shapes.keys())
        ])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        if self.pre_train_model == 'cpc':
            cpc_batch = self.cpc_staging_tf.get()
            cpc_batch_tf = OrderedDict([
                (key, cpc_batch[i])
                for i, key in enumerate(self.cpc_shape.keys())
            ])
            # self.cpc_batch_tf = {}
            # self.cpc_batch_tf['obs_neg'] = tf.placeholder(tf.float32, shape=(None, self.fixed_num_of_contact, self.contact_point_dim))
            # self.cpc_batch_tf['obs_pos'] = tf.placeholder(tf.float32, shape=(None, self.fixed_num_of_contact, self.contact_point_dim))
            # self.__dict__['cpc_inputs_tf'] = self.cpc_batch_tf

        #choose only the demo buffer samples

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf,
                                                 net_type='main',
                                                 **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            # reuse = False
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']  #next_observations
            target_batch_tf['g'] = batch_tf['g_2']  #next_goals
            self.target = self.create_actor_critic(target_batch_tf,
                                                   net_type='target',
                                                   **self.__dict__)
            vs.reuse_variables()

        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return,
                      0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(
            batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(
            tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))

        # If not training with demonstrations: standard DDPG actor loss
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(
            tf.square(self.main.pi_tf / self.max_u))

        if self.pre_train_model == 'supervised':
            self.feature_net_var = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='ddpg/main/pi/process/predicted_pos')
            pos = batch_tf['o'][:, self.contact_dim:][:, -7:-4]
            self.feature_loss_tf = tf.reduce_mean(
                tf.square(pos - self.main.features))
            # self.max_feature_loss = tf.reduce_max(tf.square(pos - self.main.features))
            feature_grads_tf = tf.gradients(self.feature_loss_tf,
                                            self.feature_net_var)
            assert len(self.feature_net_var) == len(feature_grads_tf)
            self.feature_grads_vars_tf = zip(feature_grads_tf,
                                             self.feature_net_var)
            self.feature_grad_tf = flatten_grads(grads=feature_grads_tf,
                                                 var_list=self.feature_net_var)
            self.feature_adam = MpiAdam(self.feature_net_var,
                                        scale_grad_by_procs=False)

            target_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='ddpg/target/pi/process/predicted_pos')
            self.update_feature_weights_target = [
                tf.assign(new, old)
                for (new, old) in zip(target_vars, self.feature_net_var)
            ]
        elif self.pre_train_model == 'cpc':
            self.cpc_var = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='ddpg/main/pi/process/new_cpc')
            pos = tf.reshape(batch_tf['o_2'][:, :self.contact_dim], [
                -1, self.fixed_num_of_contact,
                self.contact_dim // self.fixed_num_of_contact
            ])
            with tf.variable_scope('auxiliary'):
                self.cpc_loss_tf = compute_cpc_loss(
                    self.main.z_dim,
                    self.main.pos_features,
                    self.main.neg_features,
                    self.main.next,
                    process_type=self.process_type,
                    n_neg=self.main.n_neg,
                    type=self.main.type)
            cpc_grads_tf = tf.gradients(self.cpc_loss_tf, self.cpc_var)
            assert len(self.cpc_var) == len(cpc_grads_tf)
            self.cpc_grads_vars_tf = zip(cpc_grads_tf, self.cpc_var)
            self.cpc_grad_tf = flatten_grads(grads=cpc_grads_tf,
                                             var_list=self.cpc_var)
            self.cpc_adam = MpiAdam(self.cpc_var, scale_grad_by_procs=False)

            target_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='ddpg/target/pi/process/new_cpc')
            self.update_cpc_weights_target = [
                tf.assign(new, old)
                for (new, old) in zip(target_vars, self.cpc_var)
            ]

        elif self.pre_train_model == 'curl':
            self.W = tf.get_variable("W",
                                     shape=[self.main.z_dim, self.main.z_dim],
                                     trainable=True)
            self.encoder_var = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope='ddpg/main/pi/curl')
            self.encoder_adam = MpiAdam(self.encoder_var,
                                        scale_grad_by_procs=False)
            self.curl_var = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='ddpg/main/pi/curl') + [self.W]
            # + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='ddpg/target/pi/curl')
            self.curl_adam = MpiAdam(self.curl_var, scale_grad_by_procs=False)

            z_a = self.main.features
            z_pos = tf.stop_gradient(self.target.features)

            Wz = tf.matmul(self.W, tf.transpose(z_pos))  # (z_dim,B)
            logits = tf.matmul(z_a, Wz)  # (B,B)
            logits = logits - tf.reduce_max(logits, 1)[:, None]
            labels = tf.range(tf.shape(logits)[0])
            # `labels` are integer class indices (the positive for row i is column i),
            # so the sparse cross-entropy op is needed; reduce to a scalar loss.
            self.curl_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                               labels=labels))

            target_curl_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope='ddpg/target/pi/curl')

            self.update_curl_weights_op = list(
                map(
                    lambda v: v[0].assign(self.polyak * v[0] +
                                          (1. - self.polyak) * v[1]),
                    zip(target_curl_vars, self.encoder_var)))

            curl_grads_tf = tf.gradients(self.curl_loss, self.curl_var)
            self.curl_grad_tf = flatten_grads(grads=curl_grads_tf,
                                              var_list=self.curl_var)
            encoder_grads_tf = tf.gradients(self.curl_loss, self.encoder_var)
            self.encoder_grad_tf = flatten_grads(grads=encoder_grads_tf,
                                                 var_list=self.encoder_var)

        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf,
                                       var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf,
                                        var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'),
                               scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')

        self.stats_vars = self._global_vars('o_stats') + self._global_vars(
            'g_stats')

        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]),
                zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(
                lambda v: v[0].assign(self.polyak * v[0] +
                                      (1. - self.polyak) * v[1]),
                zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = [
            '_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats', 'main',
            'target', 'lock', 'env', 'sample_transitions', 'stage_shapes',
            'create_actor_critic'
        ]

        state = {
            k: v
            for k, v in self.__dict__.items()
            if all([not subname in k for subname in excluded_subnames])
        }
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run(
            [x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert (len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)

    def save(self, save_path):
        tf_util.save_variables(save_path)
Example #28
0
class DDPG(object):
    @store_args
    def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
                 Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
                 rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
                 sample_transitions, gamma, reuse=False, **kwargs):
        """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).

        Args:
            input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
                actions (u)
            buffer_size (int): number of transitions that are stored in the replay buffer
            hidden (int): number of units in the hidden layers
            layers (int): number of hidden layers
            network_class (str): the network class that should be used (e.g. 'baselines.her.ActorCritic')
            polyak (float): coefficient for Polyak-averaging of the target network
            batch_size (int): batch size for training
            Q_lr (float): learning rate for the Q (critic) network
            pi_lr (float): learning rate for the pi (actor) network
            norm_eps (float): a small value used in the normalizer to avoid numerical instabilities
            norm_clip (float): normalized inputs are clipped to be in [-norm_clip, norm_clip]
            max_u (float): maximum action magnitude, i.e. actions are in [-max_u, max_u]
            action_l2 (float): coefficient for L2 penalty on the actions
            clip_obs (float): clip observations before normalization to be in [-clip_obs, clip_obs]
            scope (str): the scope used for the TensorFlow graph
            T (int): the time horizon for rollouts
            rollout_batch_size (int): number of parallel rollouts per DDPG agent
            subtract_goals (function): function that subtracts goals from each other
            relative_goals (boolean): whether or not relative goals should be fed into the network
            clip_pos_returns (boolean): whether or not positive returns should be clipped
            clip_return (float): clip returns to be in [-clip_return, clip_return]
            sample_transitions (function): function that samples from the replay buffer
            gamma (float): gamma used for Q learning updates
            reuse (boolean): whether or not the networks should be reused
        """
        if self.clip_return is None:
            self.clip_return = np.inf

        self.create_actor_critic = import_function(self.network_class)

        input_shapes = dims_to_shapes(self.input_dims)
        self.dimo = self.input_dims['o']
        self.dimg = self.input_dims['g']
        self.dimu = self.input_dims['u']

        # Prepare staging area for feeding data to the model.
        stage_shapes = OrderedDict()
        for key in sorted(self.input_dims.keys()):
            if key.startswith('info_'):
                continue
            stage_shapes[key] = (None, *input_shapes[key])
        for key in ['o', 'g']:
            stage_shapes[key + '_2'] = stage_shapes[key]
        stage_shapes['r'] = (None,)
        self.stage_shapes = stage_shapes

        # Create network.
        with tf.variable_scope(self.scope):
            self.staging_tf = StagingArea(
                dtypes=[tf.float32 for _ in self.stage_shapes.keys()],
                shapes=list(self.stage_shapes.values()))
            self.buffer_ph_tf = [
                tf.placeholder(tf.float32, shape=shape) for shape in self.stage_shapes.values()]
            self.stage_op = self.staging_tf.put(self.buffer_ph_tf)

            self._create_network(reuse=reuse)

        # Configure the replay buffer.
        buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
                         for key, val in input_shapes.items()}
        buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
        buffer_shapes['ag'] = (self.T+1, self.dimg)

        buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
        self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)

    def _random_action(self, n):
        return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

    def _preprocess_og(self, o, ag, g):
        if self.relative_goals:
            g_shape = g.shape
            g = g.reshape(-1, self.dimg)
            ag = ag.reshape(-1, self.dimg)
            g = self.subtract_goals(g, ag)
            g = g.reshape(*g_shape)
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
                    compute_Q=False):
        o, g = self._preprocess_og(o, ag, g)
        policy = self.target if use_target_net else self.main
        # values to compute
        vals = [policy.pi_tf]
        if compute_Q:
            vals += [policy.Q_pi_tf]
        # feed
        feed = {
            policy.o_tf: o.reshape(-1, self.dimo),
            policy.g_tf: g.reshape(-1, self.dimg),
            policy.u_tf: np.zeros((o.size // self.dimo, self.dimu), dtype=np.float32)
        }

        ret = self.sess.run(vals, feed_dict=feed)
        # action postprocessing
        u = ret[0]
        noise = noise_eps * self.max_u * np.random.randn(*u.shape)  # gaussian noise
        u += noise
        u = np.clip(u, -self.max_u, self.max_u)
        u += np.random.binomial(1, random_eps, u.shape[0]).reshape(-1, 1) * (self._random_action(u.shape[0]) - u)  # eps-greedy
        if u.shape[0] == 1:
            u = u[0]
        u = u.copy()
        ret[0] = u

        if len(ret) == 1:
            return ret[0]
        else:
            return ret

    def store_episode(self, episode_batch, update_stats=True):
        """
        episode_batch: array of batch_size x (T or T+1) x dim_key
                       'o' is of size T+1, others are of size T
        """

        self.buffer.store_episode(episode_batch)

        if update_stats:
            # add transitions to normalizer
            episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
            episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
            transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)

            o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess the o_2 and g_2 since this is only used for stats

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])

            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()

    def get_current_buffer_size(self):
        return self.buffer.get_current_size()

    def _sync_optimizers(self):
        self.Q_adam.sync()
        self.pi_adam.sync()

    def _grads(self):
        # Avoid feed_dict here for performance!
        critic_loss, actor_loss, Q_grad, pi_grad = self.sess.run([
            self.Q_loss_tf,
            self.main.Q_pi_tf,
            self.Q_grad_tf,
            self.pi_grad_tf
        ])
        return critic_loss, actor_loss, Q_grad, pi_grad

    def _update(self, Q_grad, pi_grad):
        self.Q_adam.update(Q_grad, self.Q_lr)
        self.pi_adam.update(pi_grad, self.pi_lr)

    def sample_batch(self):
        transitions = self.buffer.sample(self.batch_size)
        o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
        ag, ag_2 = transitions['ag'], transitions['ag_2']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        transitions['o_2'], transitions['g_2'] = self._preprocess_og(o_2, ag_2, g)

        transitions_batch = [transitions[key] for key in self.stage_shapes.keys()]
        return transitions_batch

    def stage_batch(self, batch=None):
        if batch is None:
            batch = self.sample_batch()
        assert len(self.buffer_ph_tf) == len(batch)
        self.sess.run(self.stage_op, feed_dict=dict(zip(self.buffer_ph_tf, batch)))

    def train(self, stage=True):
        if stage:
            self.stage_batch()
        critic_loss, actor_loss, Q_grad, pi_grad = self._grads()
        self._update(Q_grad, pi_grad)
        return critic_loss, actor_loss

    def _init_target_net(self):
        self.sess.run(self.init_target_net_op)

    def update_target_net(self):
        self.sess.run(self.update_target_net_op)

    def clear_buffer(self):
        self.buffer.clear_buffer()

    def _vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope + '/' + scope)
        assert len(res) > 0
        return res

    def _global_vars(self, scope):
        res = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.scope + '/' + scope)
        return res

    def _create_network(self, reuse=False):
        logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))

        self.sess = tf.get_default_session()
        if self.sess is None:
            self.sess = tf.InteractiveSession()

        # running averages
        with tf.variable_scope('o_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.o_stats = Normalizer(self.dimo, self.norm_eps, self.norm_clip, sess=self.sess)
        with tf.variable_scope('g_stats') as vs:
            if reuse:
                vs.reuse_variables()
            self.g_stats = Normalizer(self.dimg, self.norm_eps, self.norm_clip, sess=self.sess)

        # mini-batch sampling.
        batch = self.staging_tf.get()
        batch_tf = OrderedDict([(key, batch[i])
                                for i, key in enumerate(self.stage_shapes.keys())])
        batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

        # networks
        with tf.variable_scope('main') as vs:
            if reuse:
                vs.reuse_variables()
            self.main = self.create_actor_critic(batch_tf, net_type='main', **self.__dict__)
            vs.reuse_variables()
        with tf.variable_scope('target') as vs:
            if reuse:
                vs.reuse_variables()
            target_batch_tf = batch_tf.copy()
            target_batch_tf['o'] = batch_tf['o_2']
            target_batch_tf['g'] = batch_tf['g_2']
            self.target = self.create_actor_critic(
                target_batch_tf, net_type='target', **self.__dict__)
            vs.reuse_variables()
        assert len(self._vars("main")) == len(self._vars("target"))

        # loss functions
        target_Q_pi_tf = self.target.Q_pi_tf
        clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
        target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
        self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
        self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
        self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
        Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
        pi_grads_tf = tf.gradients(self.pi_loss_tf, self._vars('main/pi'))
        assert len(self._vars('main/Q')) == len(Q_grads_tf)
        assert len(self._vars('main/pi')) == len(pi_grads_tf)
        self.Q_grads_vars_tf = zip(Q_grads_tf, self._vars('main/Q'))
        self.pi_grads_vars_tf = zip(pi_grads_tf, self._vars('main/pi'))
        self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self._vars('main/Q'))
        self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self._vars('main/pi'))

        # optimizers
        self.Q_adam = MpiAdam(self._vars('main/Q'), scale_grad_by_procs=False)
        self.pi_adam = MpiAdam(self._vars('main/pi'), scale_grad_by_procs=False)

        # polyak averaging
        self.main_vars = self._vars('main/Q') + self._vars('main/pi')
        self.target_vars = self._vars('target/Q') + self._vars('target/pi')
        self.stats_vars = self._global_vars('o_stats') + self._global_vars('g_stats')
        self.init_target_net_op = list(
            map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
        self.update_target_net_op = list(
            map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), zip(self.target_vars, self.main_vars)))

        # initialize all variables
        tf.variables_initializer(self._global_vars('')).run()
        self._sync_optimizers()
        self._init_target_net()

    def logs(self, prefix=''):
        logs = []
        logs += [('stats_o/mean', np.mean(self.sess.run([self.o_stats.mean])))]
        logs += [('stats_o/std', np.mean(self.sess.run([self.o_stats.std])))]
        logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
        logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]

        if prefix != '' and not prefix.endswith('/'):
            return [(prefix + '/' + key, val) for key, val in logs]
        else:
            return logs

    def __getstate__(self):
        """Our policies can be loaded from pkl, but after unpickling you cannot continue training.
        """
        excluded_subnames = ['_tf', '_op', '_vars', '_adam', 'buffer', 'sess', '_stats',
                             'main', 'target', 'lock', 'env', 'sample_transitions',
                             'stage_shapes', 'create_actor_critic']

        state = {k: v for k, v in self.__dict__.items() if all([not subname in k for subname in excluded_subnames])}
        state['buffer_size'] = self.buffer_size
        state['tf'] = self.sess.run([x for x in self._global_vars('') if 'buffer' not in x.name])
        return state

    def __setstate__(self, state):
        if 'sample_transitions' not in state:
            # We don't need this for playing the policy.
            state['sample_transitions'] = None

        self.__init__(**state)
        # set up stats (they are overwritten in __init__)
        for k, v in state.items():
            if k[-6:] == '_stats':
                self.__dict__[k] = v
        # load TF variables
        vars = [x for x in self._global_vars('') if 'buffer' not in x.name]
        assert(len(vars) == len(state["tf"]))
        node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
        self.sess.run(node)