Example #1
File: ddpg.py Project: gntoni/garage
class DDPG(RLAlgorithm):
    """
    A DDPG model based on https://arxiv.org/pdf/1509.02971.pdf.

    Example:
        $ python garage/examples/tf/ddpg_pendulum.py
    """
    def __init__(self,
                 env,
                 actor,
                 critic,
                 n_epochs=500,
                 n_epoch_cycles=20,
                 n_rollout_steps=100,
                 n_train_steps=50,
                 reward_scale=1.,
                 batch_size=64,
                 target_update_tau=0.01,
                 discount=0.99,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 actor_weight_decay=0,
                 critic_weight_decay=0,
                 replay_buffer_size=int(1e6),
                 min_buffer_size=10000,
                 exploration_strategy=None,
                 plot=False,
                 pause_for_plot=False,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 name=None):
        """
        Construct class.

        Args:
            env(): Environment.
            actor(garage.tf.policies.ContinuousMLPPolicy): Policy network.
            critic(garage.tf.q_functions.ContinuousMLPQFunction):
         Q Value network.
            n_epochs(int, optional): Number of epochs.
            n_epoch_cycles(int, optional): Number of epoch cycles.
            n_rollout_steps(int, optional): Number of rollout steps.
            n_train_steps(int, optional): Number of train steps.
            reward_scale(float): The scaling factor applied to the rewards when
         training.
            batch_size(int): Number of samples for each minibatch.
            target_update_tau(float): Interpolation parameter for doing the
         soft target update.
            discount(float): Discount factor for the cumulative return.
            actor_lr(float): Learning rate for training policy network.
            critic_lr(float): Learning rate for training q value network.
            actor_weight_decay(float): L2 weight decay factor for parameters of
         the policy network.
            critic_weight_decay(float): L2 weight decay factor for parameters
         of the q value network.
            replay_buffer_size(int): Size of the replay buffer.
            min_buffer_size(int): Minimum size of the replay buffer to start
         training.
            exploration_strategy(): Exploration strategy.
            plot(bool): Whether to visualize the policy performance after each
         eval_interval.
            pause_for_plot(bool): Whether to pause before continuing when
         plotting.
            actor_optimizer(): Optimizer for training policy network.
            critic_optimizer(): Optimizer for training q function network.
        """
        self.env = env

        self.observation_dim = flat_dim(env.observation_space)
        self.action_dim = flat_dim(env.action_space)
        _, self.action_bound = bounds(env.action_space)

        self.actor = actor
        self.critic = critic
        self.n_epochs = n_epochs
        self.n_epoch_cycles = n_epoch_cycles
        self.n_rollout_steps = n_rollout_steps
        self.n_train_steps = n_train_steps
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.tau = target_update_tau
        self.discount = discount
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_weight_decay = actor_weight_decay
        self.critic_weight_decay = critic_weight_decay
        self.replay_buffer_size = replay_buffer_size
        self.min_buffer_size = min_buffer_size
        self.es = exploration_strategy
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.name = name
        self._initialize()

    @overrides
    def train(self, sess=None):
        """
        Training process of the DDPG algorithm.

        Args:
            sess: A TensorFlow session for executing ops.
        """
        replay_buffer = self.opt_info["replay_buffer"]
        f_init_target = self.opt_info["f_init_target"]
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        # Start plotter
        if self.plot:
            self.plotter = Plotter(self.env, self.actor, sess)
            self.plotter.start()

        sess.run(tf.global_variables_initializer())
        f_init_target()

        observation = self.env.reset()
        if self.es:
            self.es.reset()

        episode_reward = 0.
        episode_step = 0
        episode_rewards = []
        episode_steps = []
        episode_actor_losses = []
        episode_critic_losses = []
        episodes = 0
        epoch_ys = []
        epoch_qs = []

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            for epoch_cycle in pyprind.prog_bar(range(self.n_epoch_cycles)):
                for rollout in range(self.n_rollout_steps):
                    action = self.es.get_action(rollout, observation,
                                                self.actor)
                    assert action.shape == self.env.action_space.shape

                    next_observation, reward, terminal, info = self.env.step(
                        action)
                    episode_reward += reward
                    episode_step += 1

                    replay_buffer.add_transition(observation, action,
                                                 reward * self.reward_scale,
                                                 terminal, next_observation)

                    observation = next_observation

                    if terminal:
                        episode_rewards.append(episode_reward)
                        episode_steps.append(episode_step)
                        episode_reward = 0.
                        episode_step = 0
                        episodes += 1

                        observation = self.env.reset()
                        if self.es:
                            self.es.reset()

                for train_itr in range(self.n_train_steps):
                    if replay_buffer.size >= self.min_buffer_size:
                        critic_loss, y, q, action_loss = self._learn()

                        episode_actor_losses.append(action_loss)
                        episode_critic_losses.append(critic_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)

            logger.log("Training finished")
            if replay_buffer.size >= self.min_buffer_size:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('Episodes', episodes)
                logger.record_tabular('AverageReturn',
                                      np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_actor_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_critic_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))

                # Uncomment the following if you want to calculate the average
                # in each epoch
                # episode_rewards = []
                # episode_actor_losses = []
                # episode_critic_losses = []
                # epoch_ys = []
                # epoch_qs = []

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.actor, self.n_rollout_steps)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        if self.plot:
            self.plotter.shutdown()
        if created_session:
            sess.close()

    def _initialize(self):
        with tf.name_scope(self.name, "DDPG"):
            with tf.name_scope("setup_networks"):
                """Set up the actor, critic and target network."""
                # Set up the actor and critic network
                self.actor._build_net(trainable=True)
                self.critic._build_net(trainable=True)

                # Create target actor and critic network
                target_actor = copy(self.actor)
                target_critic = copy(self.critic)

                # Set up the target network
                target_actor.name = "TargetActor"
                target_actor._build_net(trainable=False)
                target_critic.name = "TargetCritic"
                target_critic._build_net(trainable=False)

            # Initialize replay buffer
            replay_buffer = ReplayBuffer(self.replay_buffer_size,
                                         self.observation_dim, self.action_dim)

            # Set up target init and update function
            with tf.name_scope("setup_target"):
                actor_init_ops, actor_update_ops = get_target_ops(
                    self.actor.global_vars, target_actor.global_vars, self.tau)
                critic_init_ops, critic_update_ops = get_target_ops(
                    self.critic.global_vars, target_critic.global_vars,
                    self.tau)
                target_init_op = actor_init_ops + critic_init_ops
                target_update_op = actor_update_ops + critic_update_ops

            f_init_target = tensor_utils.compile_function(
                inputs=[], outputs=target_init_op)
            f_update_target = tensor_utils.compile_function(
                inputs=[], outputs=target_update_op)

            with tf.name_scope("inputs"):
                y = tf.placeholder(tf.float32, shape=(None, 1), name="input_y")
                obs = tf.placeholder(tf.float32,
                                     shape=(None, self.observation_dim),
                                     name="input_observation")
                actions = tf.placeholder(tf.float32,
                                         shape=(None, self.action_dim),
                                         name="input_action")

            # Set up actor training function
            next_action = self.actor.get_action_sym(obs, name="actor_action")
            next_qval = self.critic.get_qval_sym(obs,
                                                 next_action,
                                                 name="actor_qval")
            with tf.name_scope("action_loss"):
                action_loss = -tf.reduce_mean(next_qval)
                if self.actor_weight_decay > 0.:
                    actor_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.actor_weight_decay),
                        weights_list=self.actor.regularizable_vars)
                    action_loss += actor_reg

            with tf.name_scope("minimize_action_loss"):
                actor_train_op = self.actor_optimizer(
                    self.actor_lr, name="ActorOptimizer").minimize(
                        action_loss, var_list=self.actor.trainable_vars)

            f_train_actor = tensor_utils.compile_function(
                inputs=[obs], outputs=[actor_train_op, action_loss])

            # Set up critic training function
            qval = self.critic.get_qval_sym(obs, actions, name="q_value")
            with tf.name_scope("qval_loss"):
                qval_loss = tf.reduce_mean(tf.squared_difference(y, qval))
                if self.critic_weight_decay > 0.:
                    critic_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.critic_weight_decay),
                        weights_list=self.critic.regularizable_vars)
                    qval_loss += critic_reg

            with tf.name_scope("minimize_critic_loss"):
                critic_train_op = self.critic_optimizer(
                    self.critic_lr, name="CriticOptimizer").minimize(
                        qval_loss, var_list=self.critic.trainable_vars)

            f_train_critic = tensor_utils.compile_function(
                inputs=[y, obs, actions],
                outputs=[critic_train_op, qval_loss, qval])

            self.opt_info = dict(f_train_actor=f_train_actor,
                                 f_train_critic=f_train_critic,
                                 f_init_target=f_init_target,
                                 f_update_target=f_update_target,
                                 replay_buffer=replay_buffer,
                                 target_critic=target_critic,
                                 target_actor=target_actor)

    def _learn(self):
        """
        Perform one iteration of DDPG optimization.

        Returns:
            qval_loss: Loss of the Q value predicted by the Q network.
            ys: Regression targets for the Q network (Bellman backups).
            qval: Q value predicted by the Q network.
            action_loss: Loss of the action predicted by the policy network.

        """
        replay_buffer = self.opt_info["replay_buffer"]
        target_actor = self.opt_info["target_actor"]
        target_critic = self.opt_info["target_critic"]
        f_train_critic = self.opt_info["f_train_critic"]
        f_train_actor = self.opt_info["f_train_actor"]
        f_update_target = self.opt_info["f_update_target"]

        transitions = replay_buffer.random_sample(self.batch_size)
        observations = transitions["observations"]
        rewards = transitions["rewards"]
        actions = transitions["actions"]
        terminals = transitions["terminals"]
        next_observations = transitions["next_observations"]

        rewards = rewards.reshape(-1, 1)
        terminals = terminals.reshape(-1, 1)

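        # One-step Bellman backup:
        #     y = r + (1 - terminal) * gamma * Q'(s', mu'(s'))
        # where Q' and mu' are the target critic and target actor.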
        target_actions = target_actor.get_actions(next_observations)
        target_qvals = target_critic.get_qval(next_observations,
                                              target_actions)

        ys = rewards + (1.0 - terminals) * self.discount * target_qvals

        _, qval_loss, qval = f_train_critic(ys, observations, actions)
        _, action_loss = f_train_actor(observations)
        f_update_target()

        return qval_loss, ys, qval, action_loss
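
A minimal usage sketch follows. The wiring is assumed: the imports,
constructor arguments, and exploration strategy below mirror what
garage/examples/tf/ddpg_pendulum.py likely does, but they are illustrative
rather than taken from this file.

import gym
import tensorflow as tf

from garage.tf.policies import ContinuousMLPPolicy         # assumed path
from garage.tf.q_functions import ContinuousMLPQFunction   # assumed path
from garage.exploration_strategies import OUStrategy       # assumed path

env = gym.make("Pendulum-v0")
actor = ContinuousMLPPolicy(env_spec=env.spec, hidden_sizes=(64, 64))
critic = ContinuousMLPQFunction(env_spec=env.spec)

algo = DDPG(
    env,
    actor=actor,
    critic=critic,
    exploration_strategy=OUStrategy(env.spec),
    # Optimizers are passed as classes; DDPG instantiates them with the
    # corresponding learning rate (see minimize_action_loss above).
    actor_optimizer=tf.train.AdamOptimizer,
    critic_optimizer=tf.train.AdamOptimizer,
)
algo.train()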
Example #2
class BatchPolopt(RLAlgorithm):
    """
    Base class for batch sampling-based policy optimization methods.
    This includes various policy gradient methods like vpg, npg, ppo, trpo,
    etc.
    """
    def __init__(self,
                 env,
                 policy,
                 baseline,
                 scope=None,
                 n_itr=500,
                 start_itr=0,
                 batch_size=5000,
                 max_path_length=500,
                 discount=0.99,
                 gae_lambda=1,
                 plot=False,
                 pause_for_plot=False,
                 center_adv=True,
                 positive_adv=False,
                 store_paths=False,
                 whole_paths=True,
                 fixed_horizon=False,
                 sampler_cls=None,
                 sampler_args=None,
                 force_batch_sampler=False,
                 **kwargs):
        """
        :param env: Environment
        :param policy: Policy
        :type policy: Policy
        :param baseline: Baseline
        :param scope: Scope for identifying the algorithm. Must be specified
         if running multiple algorithms simultaneously, each using different
         environments and policies
        :param n_itr: Number of iterations.
        :param start_itr: Starting iteration.
        :param batch_size: Number of samples per iteration.
        :param max_path_length: Maximum length of a single rollout.
        :param discount: Discount.
        :param gae_lambda: Lambda used for generalized advantage estimation.
        :param plot: Plot evaluation run after each iteration.
        :param pause_for_plot: Whether to pause before continuing when
         plotting.
        :param center_adv: Whether to rescale the advantages so that they have
         mean 0 and standard deviation 1.
        :param positive_adv: Whether to shift the advantages so that they are
         always positive. When used in conjunction with center_adv the
         advantages will be standardized before shifting.
        :param store_paths: Whether to save all paths data to the snapshot.
        :return:
        """
        self.env = env
        self.policy = policy
        self.baseline = baseline
        self.scope = scope
        self.n_itr = n_itr
        self.start_itr = start_itr
        self.batch_size = batch_size
        self.max_path_length = max_path_length
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.center_adv = center_adv
        self.positive_adv = positive_adv
        self.store_paths = store_paths
        self.whole_paths = whole_paths
        self.fixed_horizon = fixed_horizon
        if sampler_cls is None:
            if self.policy.vectorized and not force_batch_sampler:
                sampler_cls = VectorizedSampler
            else:
                sampler_cls = BatchSampler
        if sampler_args is None:
            sampler_args = dict()
        self.sampler = sampler_cls(self, **sampler_args)
        self.init_opt()

    def start_worker(self, sess):
        self.sampler.start_worker()
        if self.plot:
            self.plotter = Plotter(self.env, self.policy, sess)
            self.plotter.start()

    def shutdown_worker(self):
        self.sampler.shutdown_worker()
        if self.plot:
            self.plotter.shutdown()

    def obtain_samples(self, itr):
        return self.sampler.obtain_samples(itr)

    def process_samples(self, itr, paths):
        return self.sampler.process_samples(itr, paths)

    def train(self, sess=None):
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        sess.run(tf.global_variables_initializer())
        self.start_worker(sess)
        start_time = time.time()
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            with logger.prefix('itr #%d | ' % itr):
                logger.log("Obtaining samples...")
                paths = self.obtain_samples(itr)
                logger.log("Processing samples...")
                samples_data = self.process_samples(itr, paths)
                logger.log("Logging diagnostics...")
                self.log_diagnostics(paths)
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
                params = self.get_itr_snapshot(itr, samples_data)
                if self.store_paths:
                    params["paths"] = samples_data["paths"]
                logger.save_itr_params(itr, params)
                logger.log("Saved")
                logger.record_tabular('Time', time.time() - start_time)
                logger.record_tabular('ItrTime', time.time() - itr_start_time)
                logger.dump_tabular(with_prefix=False)
                if self.plot:
                    self.plotter.update_plot(self.policy, self.max_path_length)
                    if self.pause_for_plot:
                        input("Plotting evaluation run: Press Enter to "
                              "continue...")

        self.shutdown_worker()
        if created_session:
            sess.close()

    def log_diagnostics(self, paths):
        self.env.log_diagnostics(paths)
        self.policy.log_diagnostics(paths)
        self.baseline.log_diagnostics(paths)

    def init_opt(self):
        """
        Initialize the optimization procedure. If using tensorflow, this may
        include declaring all the variables and compiling functions
        """
        raise NotImplementedError

    def get_itr_snapshot(self, itr, samples_data):
        """
        Returns all the data that should be saved in the snapshot for this
        iteration.
        """
        raise NotImplementedError

    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError
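
Since init_opt, get_itr_snapshot, and optimize_policy raise
NotImplementedError, concrete algorithms subclass BatchPolopt. A minimal
skeleton (the class name and method bodies are placeholders, not a real
algorithm):

class MyPolopt(BatchPolopt):
    """Skeleton showing the hooks a BatchPolopt subclass must provide."""

    def init_opt(self):
        # Declare variables and compile the training functions for
        # self.policy here; called once at the end of __init__.
        pass

    def get_itr_snapshot(self, itr, samples_data):
        # Whatever is returned here is persisted by logger.save_itr_params.
        return dict(itr=itr, policy=self.policy, baseline=self.baseline,
                    env=self.env)

    def optimize_policy(self, itr, samples_data):
        # Consume the processed samples (observations, actions,
        # advantages, ...) and take one policy optimization step.
        pass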
Example #3
class DDPG(RLAlgorithm):
    """
    A DDPG model based on https://arxiv.org/pdf/1509.02971.pdf.

    Example:
        $ python garage/examples/tf/ddpg_pendulum.py
    """

    def __init__(self,
                 env,
                 actor,
                 critic,
                 n_epochs=500,
                 n_epoch_cycles=20,
                 n_rollout_steps=100,
                 n_train_steps=50,
                 reward_scale=1.,
                 batch_size=64,
                 target_update_tau=0.01,
                 discount=0.99,
                 actor_lr=1e-4,
                 critic_lr=1e-3,
                 actor_weight_decay=0,
                 critic_weight_decay=0,
                 replay_buffer_size=int(1e6),
                 min_buffer_size=10000,
                 exploration_strategy=None,
                 plot=False,
                 pause_for_plot=False,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 use_her=False,
                 clip_obs=np.inf,
                 clip_pos_returns=True,
                 clip_return=None,
                 replay_k=4,
                 max_action=None,
                 name=None):
        """
        Construct class.

        Args:
            env(): Environment.
            actor(garage.tf.policies.ContinuousMLPPolicy): Policy network.
            critic(garage.tf.q_functions.ContinuousMLPQFunction):
        Q Value network.
            n_epochs(int, optional): Number of epochs.
            n_epoch_cycles(int, optional): Number of epoch cycles.
            n_rollout_steps(int, optional): Number of rollout steps.
        aka the time horizon of rollout.
            n_train_steps(int, optional): Number of train steps.
            reward_scale(float): The scaling factor applied to the rewards when
        training.
            batch_size(int): Number of samples for each minibatch.
            target_update_tau(float): Interpolation parameter for doing the
        soft target update.
            discount(float): Discount factor for the cumulative return.
            actor_lr(float): Learning rate for training policy network.
            critic_lr(float): Learning rate for training q value network.
            actor_weight_decay(float): L2 weight decay factor for parameters of
        the policy network.
            critic_weight_decay(float): L2 weight decay factor for parameters
        of the q value network.
            replay_buffer_size(int): Size of the replay buffer.
            min_buffer_size(int): Minimum size of the replay buffer to start
        training.
            exploration_strategy(): Exploration strategy to randomize the
        action.
            plot(boolean): Whether to visualize the policy performance after
        each eval_interval.
            pause_for_plot(boolean): Whether or not pause before continuing
        when plotting.
            actor_optimizer(): Optimizer for training policy network.
            critic_optimizer(): Optimizer for training q function network.
            use_her(boolean): Whether or not use HER for replay buffer.
            clip_obs(float): Clip observation to be in [-clip_obs, clip_obs].
            clip_pos_returns(boolean): Whether or not clip positive returns.
            clip_return(float): Clip return to be in [-clip_return,
        clip_return].
            replay_k(int): The ratio between HER replays and regular replays.
        Only used when use_her is True.
            max_action(float): Maximum action magnitude.
            name(str): Name of the algorithm shown in computation graph.
        """
        self.env = env

        self.input_dims = configure_dims(env)
        action_bound = env.action_space.high
        self.max_action = action_bound if max_action is None else max_action

        self.actor = actor
        self.critic = critic
        self.n_epochs = n_epochs
        self.n_epoch_cycles = n_epoch_cycles
        self.n_rollout_steps = n_rollout_steps
        self.n_train_steps = n_train_steps
        self.reward_scale = reward_scale
        self.batch_size = batch_size
        self.tau = target_update_tau
        self.discount = discount
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.actor_weight_decay = actor_weight_decay
        self.critic_weight_decay = critic_weight_decay
        self.replay_buffer_size = replay_buffer_size
        self.min_buffer_size = min_buffer_size
        self.es = exploration_strategy
        self.plot = plot
        self.pause_for_plot = pause_for_plot
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.name = name
        self.use_her = use_her
        self.evaluate = False
        self.replay_k = replay_k
        self.clip_return = (
            1. / (1. - self.discount)) if clip_return is None else clip_return
        self.clip_obs = clip_obs
        self.clip_pos_returns = clip_pos_returns
        self.success_history = deque(maxlen=100)
        self._initialize()

    @overrides
    def train(self, sess=None):
        """
        Training process of the DDPG algorithm.

        Args:
            sess: A TensorFlow session for executing ops.
        """
        created_session = sess is None
        if sess is None:
            sess = tf.Session()
            sess.__enter__()

        # Start plotter
        if self.plot:
            self.plotter = Plotter(self.env, self.actor, sess)
            self.plotter.start()

        sess.run(tf.global_variables_initializer())
        self.f_init_target()

        observation = self.env.reset()
        if self.es:
            self.es.reset()

        episode_reward = 0.
        episode_step = 0
        episode_rewards = []
        episode_steps = []
        episode_actor_losses = []
        episode_critic_losses = []
        episodes = 0
        epoch_ys = []
        epoch_qs = []

        for epoch in range(self.n_epochs):
            logger.push_prefix('epoch #%d | ' % epoch)
            logger.log("Training started")
            self.success_history.clear()
            for epoch_cycle in pyprind.prog_bar(range(self.n_epoch_cycles)):
                if self.use_her:
                    successes = []
                    for rollout in range(self.n_rollout_steps):
                        o = np.clip(observation["observation"], -self.clip_obs,
                                    self.clip_obs)
                        g = np.clip(observation["desired_goal"],
                                    -self.clip_obs, self.clip_obs)
                        obs_goal = np.concatenate((o, g), axis=-1)
                        action = self.es.get_action(rollout, obs_goal,
                                                    self.actor)

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        if 'is_success' in info:
                            successes.append([info["is_success"]])
                        episode_reward += reward
                        episode_step += 1

                        info_dict = {
                            "info_{}".format(key): info[key].reshape(1)
                            for key in info.keys()
                        }
                        self.replay_buffer.add_transition(
                            observation=observation['observation'],
                            action=action,
                            goal=observation['desired_goal'],
                            achieved_goal=observation['achieved_goal'],
                            **info_dict,
                        )

                        observation = next_observation

                        if rollout == self.n_rollout_steps - 1:
                            self.replay_buffer.add_transition(
                                observation=observation['observation'],
                                achieved_goal=observation['achieved_goal'])

                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

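                    # An episode counts as successful based on the final
                    # rollout step's is_success flag.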
                    successful = np.array(successes)[-1, :]
                    success_rate = np.mean(successful)
                    self.success_history.append(success_rate)

                    for train_itr in range(self.n_train_steps):
                        self.evaluate = True
                        critic_loss, y, q, action_loss = self._learn()

                        episode_actor_losses.append(action_loss)
                        episode_critic_losses.append(critic_loss)
                        epoch_ys.append(y)
                        epoch_qs.append(q)

                    self.f_update_target()
                else:
                    for rollout in range(self.n_rollout_steps):
                        action = self.es.get_action(rollout, observation,
                                                    self.actor)
                        assert action.shape == self.env.action_space.shape

                        next_observation, reward, terminal, info = self.env.step(  # noqa: E501
                            action)
                        episode_reward += reward
                        episode_step += 1

                        self.replay_buffer.add_transition(
                            observation=observation,
                            action=action,
                            reward=reward * self.reward_scale,
                            terminal=terminal,
                            next_observation=next_observation,
                        )

                        observation = next_observation

                        if terminal or rollout == self.n_rollout_steps - 1:
                            episode_rewards.append(episode_reward)
                            episode_steps.append(episode_step)
                            episode_reward = 0.
                            episode_step = 0
                            episodes += 1

                            observation = self.env.reset()
                            if self.es:
                                self.es.reset()

                    for train_itr in range(self.n_train_steps):
                        if self.replay_buffer.size >= self.min_buffer_size:
                            self.evaluate = True
                            critic_loss, y, q, action_loss = self._learn()

                            episode_actor_losses.append(action_loss)
                            episode_critic_losses.append(critic_loss)
                            epoch_ys.append(y)
                            epoch_qs.append(q)

            logger.log("Training finished")
            logger.log("Saving snapshot")
            itr = epoch * self.n_epoch_cycles + epoch_cycle
            params = self.get_itr_snapshot(itr)
            logger.save_itr_params(itr, params)
            logger.log("Saved")
            if self.evaluate:
                logger.record_tabular('Epoch', epoch)
                logger.record_tabular('Episodes', episodes)
                logger.record_tabular('AverageReturn',
                                      np.mean(episode_rewards))
                logger.record_tabular('StdReturn', np.std(episode_rewards))
                logger.record_tabular('Policy/AveragePolicyLoss',
                                      np.mean(episode_actor_losses))
                logger.record_tabular('QFunction/AverageQFunctionLoss',
                                      np.mean(episode_critic_losses))
                logger.record_tabular('QFunction/AverageQ', np.mean(epoch_qs))
                logger.record_tabular('QFunction/MaxQ', np.max(epoch_qs))
                logger.record_tabular('QFunction/AverageAbsQ',
                                      np.mean(np.abs(epoch_qs)))
                logger.record_tabular('QFunction/AverageY', np.mean(epoch_ys))
                logger.record_tabular('QFunction/MaxY', np.max(epoch_ys))
                logger.record_tabular('QFunction/AverageAbsY',
                                      np.mean(np.abs(epoch_ys)))
                if self.use_her:
                    logger.record_tabular('AverageSuccessRate',
                                          np.mean(self.success_history))

                # Uncomment the following if you want to calculate the average
                # per epoch; this is recommended when self.use_her is True
                # episode_rewards = []
                # episode_actor_losses = []
                # episode_critic_losses = []
                # epoch_ys = []
                # epoch_qs = []

            logger.dump_tabular(with_prefix=False)
            logger.pop_prefix()
            if self.plot:
                self.plotter.update_plot(self.actor, self.n_rollout_steps)
                if self.pause_for_plot:
                    input("Plotting evaluation run: Press Enter to "
                          "continue...")

        if self.plot:
            self.plotter.shutdown()
        if created_session:
            sess.close()

    def _initialize(self):
        with tf.name_scope(self.name, "DDPG"):
            with tf.name_scope("setup_networks"):
                """Set up the actor, critic and target network."""
                # Set up the actor and critic network
                self.actor._build_net(trainable=True)
                self.critic._build_net(trainable=True)

                # Create target actor and critic network
                target_actor = copy(self.actor)
                target_critic = copy(self.critic)

                # Set up the target network
                target_actor.name = "TargetActor"
                target_actor._build_net(trainable=False)
                target_critic.name = "TargetCritic"
                target_critic._build_net(trainable=False)

            input_shapes = dims_to_shapes(self.input_dims)

            # Initialize replay buffer
            if self.use_her:
                buffer_shapes = {
                    key: (self.n_rollout_steps + 1
                          if key == "observation" or key == "achieved_goal"
                          else self.n_rollout_steps, *input_shapes[key])
                    for key, val in input_shapes.items()
                }

                replay_buffer = HerReplayBuffer(
                    buffer_shapes=buffer_shapes,
                    size_in_transitions=self.replay_buffer_size,
                    time_horizon=self.n_rollout_steps,
                    sample_transitions=make_her_sample(
                        self.replay_k, self.env.compute_reward))
            else:
                replay_buffer = ReplayBuffer(
                    buffer_shapes=input_shapes,
                    max_buffer_size=self.replay_buffer_size)

            # Set up target init and update function
            with tf.name_scope("setup_target"):
                actor_init_ops, actor_update_ops = get_target_ops(
                    self.actor.global_vars, target_actor.global_vars, self.tau)
                critic_init_ops, critic_update_ops = get_target_ops(
                    self.critic.global_vars, target_critic.global_vars,
                    self.tau)
                target_init_op = actor_init_ops + critic_init_ops
                target_update_op = actor_update_ops + critic_update_ops

            f_init_target = tensor_utils.compile_function(
                inputs=[], outputs=target_init_op)
            f_update_target = tensor_utils.compile_function(
                inputs=[], outputs=target_update_op)

            with tf.name_scope("inputs"):
                obs_dim = (
                    self.input_dims["observation"] + self.input_dims["goal"]
                ) if self.use_her else self.input_dims["observation"]
                y = tf.placeholder(tf.float32, shape=(None, 1), name="input_y")
                obs = tf.placeholder(
                    tf.float32,
                    shape=(None, obs_dim),
                    name="input_observation")
                actions = tf.placeholder(
                    tf.float32,
                    shape=(None, self.input_dims["action"]),
                    name="input_action")

            # Set up actor training function
            next_action = self.actor.get_action_sym(obs, name="actor_action")
            next_qval = self.critic.get_qval_sym(
                obs, next_action, name="actor_qval")
            with tf.name_scope("action_loss"):
                action_loss = -tf.reduce_mean(next_qval)
                if self.actor_weight_decay > 0.:
                    actor_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.actor_weight_decay),
                        weights_list=self.actor.regularizable_vars)
                    action_loss += actor_reg

            with tf.name_scope("minimize_action_loss"):
                actor_train_op = self.actor_optimizer(
                    self.actor_lr, name="ActorOptimizer").minimize(
                        action_loss, var_list=self.actor.trainable_vars)

            f_train_actor = tensor_utils.compile_function(
                inputs=[obs], outputs=[actor_train_op, action_loss])

            # Set up critic training function
            qval = self.critic.get_qval_sym(obs, actions, name="q_value")
            with tf.name_scope("qval_loss"):
                qval_loss = tf.reduce_mean(tf.squared_difference(y, qval))
                if self.critic_weight_decay > 0.:
                    critic_reg = tc.layers.apply_regularization(
                        tc.layers.l2_regularizer(self.critic_weight_decay),
                        weights_list=self.critic.regularizable_vars)
                    qval_loss += critic_reg

            with tf.name_scope("minimize_critic_loss"):
                critic_train_op = self.critic_optimizer(
                    self.critic_lr, name="CriticOptimizer").minimize(
                        qval_loss, var_list=self.critic.trainable_vars)

            f_train_critic = tensor_utils.compile_function(
                inputs=[y, obs, actions],
                outputs=[critic_train_op, qval_loss, qval])

            self.f_train_actor = f_train_actor
            self.f_train_critic = f_train_critic
            self.f_init_target = f_init_target
            self.f_update_target = f_update_target
            self.replay_buffer = replay_buffer
            self.target_critic = target_critic
            self.target_actor = target_actor

    def _learn(self):
        """
        Perform one iteration of DDPG optimization.

        Returns:
            qval_loss: Loss of the Q value predicted by the Q network.
            ys: Regression targets for the Q network (Bellman backups).
            qval: Q value predicted by the Q network.
            action_loss: Loss of the action predicted by the policy network.

        """
        if self.use_her:
            transitions = self.replay_buffer.sample(self.batch_size)
            observations = transitions["observation"]
            rewards = transitions["reward"]
            actions = transitions["action"]
            next_observations = transitions["next_observation"]
            goals = transitions["goal"]

            next_inputs = np.concatenate((next_observations, goals), axis=-1)
            inputs = np.concatenate((observations, goals), axis=-1)

            rewards = rewards.reshape(-1, 1)

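            # Clipped HER target:
            #     y = clip(r + gamma * Q'(s'||g, mu'(s'||g)),
            #              -clip_return, upper)
            # where upper is 0 when positive returns are clipped, else +inf.
            # There is no (1 - terminal) mask in this branch; HER rollouts
            # here run for a fixed n_rollout_steps horizon.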
            target_actions, _ = self.target_actor.get_actions(next_inputs)
            target_qvals = self.target_critic.get_qval(next_inputs,
                                                       target_actions)

            clip_range = (-self.clip_return,
                          0. if self.clip_pos_returns else np.inf)
            ys = np.clip(rewards + self.discount * target_qvals, clip_range[0],
                         clip_range[1])

            _, qval_loss, qval = self.f_train_critic(ys, inputs, actions)
            _, action_loss = self.f_train_actor(inputs)
        else:
            transitions = self.replay_buffer.sample(self.batch_size)
            observations = transitions["observation"]
            rewards = transitions["reward"]
            actions = transitions["action"]
            terminals = transitions["terminal"]
            next_observations = transitions["next_observation"]

            rewards = rewards.reshape(-1, 1)
            terminals = terminals.reshape(-1, 1)

            target_actions, _ = self.target_actor.get_actions(
                next_observations)
            target_qvals = self.target_critic.get_qval(next_observations,
                                                       target_actions)

            ys = rewards + (1.0 - terminals) * self.discount * target_qvals

            _, qval_loss, qval = self.f_train_critic(ys, observations, actions)
            _, action_loss = self.f_train_actor(observations)
            self.f_update_target()

        return qval_loss, ys, qval, action_loss

    def get_itr_snapshot(self, itr):
        """Return the data saved in the snapshot for this iteration."""
        return dict(itr=itr, policy=self.actor, env=self.env)
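
For reference, a minimal numpy sketch of the soft target update that
get_target_ops presumably builds (assumed semantics, not the actual
implementation):

import numpy as np

def soft_update(target_params, params, tau):
    """theta_target <- tau * theta + (1 - tau) * theta_target."""
    return [tau * p + (1.0 - tau) * tp
            for p, tp in zip(params, target_params)]

# With the default target_update_tau of 0.01, each update moves the target
# network only 1% of the way toward the online network.
params = [np.array([1.0, 2.0])]
targets = [np.array([0.0, 0.0])]
print(soft_update(targets, params, tau=0.01))  # [array([0.01, 0.02])]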