Example #1
class EntropyCoeffSchedule:
    """Mixin for TorchPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = entropy_coeff
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None,
                )
            else:
                # Implements previous version but enforces outside_value
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None,
                )
            self.entropy_coeff = self._entropy_coeff_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            self.entropy_coeff = self._entropy_coeff_schedule.value(
                global_vars["timestep"]
            )
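
The mixin above accepts entropy_coeff_schedule either as a list of [timestep, value] endpoints or as a single cutoff timestep after which the coefficient has decayed to 0.0. A minimal standalone sketch of how those two formats map onto PiecewiseSchedule (the import path and the default linear interpolation are assumptions about the RLlib version in use):

from ray.rllib.utils.schedules import PiecewiseSchedule

# Format 1: explicit [[timestep, value], ...] endpoints.
list_schedule = PiecewiseSchedule(
    [[0, 0.01], [1_000_000, 0.0]],
    outside_value=0.0,
    framework=None,
)

# Format 2: a single number T, read as "decay linearly from entropy_coeff
# to 0.0 by timestep T".
T = 500_000
scalar_schedule = PiecewiseSchedule(
    [[0, 0.01], [T, 0.0]],
    outside_value=0.0,
    framework=None,
)

for t in (0, 250_000, 500_000, 2_000_000):
    print(t, list_schedule.value(t), scalar_schedule.value(t))
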
Example #2
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""
    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = tf1.get_variable("lr",
                                           initializer=lr,
                                           trainable=False)
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
            self.cur_lr = tf1.get_variable(
                "lr", initializer=self._lr_schedule.value(0), trainable=False)
            if self.framework == "tf":
                self._lr_placeholder = tf1.placeholder(dtype=tf.float32,
                                                       name="lr")
                self._lr_update = self.cur_lr.assign(self._lr_placeholder,
                                                     read_value=False)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        if self._lr_schedule is not None:
            new_val = self._lr_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._lr_update, feed_dict={self._lr_placeholder: new_val})
            else:
                self.cur_lr.assign(new_val, read_value=False)
                self._optimizer.learning_rate.assign(self.cur_lr)

    @override(TFPolicy)
    def optimizer(self):
        return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)
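
In the static-graph ("tf") branch above, the new learning rate is pushed through a placeholder and an assign op that are run in the policy's session. A standalone sketch of that pattern using the tf.compat.v1 API (eager execution disabled, as in RLlib's "tf" framework mode; the 5e-4 and 1e-4 values are illustrative):

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Non-trainable variable holding the current learning rate.
cur_lr = tf1.get_variable("lr", initializer=5e-4, trainable=False)
# Placeholder plus assign op; new values are pushed via session.run().
lr_placeholder = tf1.placeholder(dtype=tf.float32, name="lr")
lr_update = cur_lr.assign(lr_placeholder, read_value=False)

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    sess.run(lr_update, feed_dict={lr_placeholder: 1e-4})
    print(sess.run(cur_lr))  # ~1e-4
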
Example #3
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = get_variable(
                entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False
            )
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None,
                )
            else:
                # Implements previous version but enforces outside_value
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None,
                )

            self.entropy_coeff = get_variable(
                self._entropy_coeff_schedule.value(0),
                framework="tf",
                tf_name="entropy_coeff",
                trainable=False,
            )
            if self.framework == "tf":
                self._entropy_coeff_placeholder = tf1.placeholder(
                    dtype=tf.float32, name="entropy_coeff"
                )
                self._entropy_coeff_update = self.entropy_coeff.assign(
                    self._entropy_coeff_placeholder, read_value=False
                )

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            new_val = self._entropy_coeff_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._entropy_coeff_update,
                    feed_dict={self._entropy_coeff_placeholder: new_val},
                )
            else:
                self.entropy_coeff.assign(new_val, read_value=False)
Example #4
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that adds entropy coeff decay."""
    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self.entropy_coeff = tf.get_variable("entropy_coeff",
                                             initializer=entropy_coeff,
                                             trainable=False)

        if entropy_coeff_schedule is None:
            self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff)
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self.entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1])
            else:
                # Implements previous version but enforces outside_value
                self.entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
        self.entropy_coeff.load(self.entropy_coeff_schedule.value(
            global_vars["timestep"]),
                                session=self._sess)
Example #5
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that adds entropy coeff decay."""
    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
        self.entropy_coeff = get_variable(entropy_coeff,
                                          framework="tf",
                                          tf_name="entropy_coeff",
                                          trainable=False)

        if entropy_coeff_schedule is None:
            self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff,
                                                           framework=None)
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self.entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None)
            else:
                # Implements previous version but enforces outside_value
                self.entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
        op_or_none = self.entropy_coeff.assign(
            self.entropy_coeff_schedule.value(global_vars["timestep"]),
            read_value=False,  # return tf op (None in eager mode).
        )
        if self._sess is not None:
            self._sess.run(op_or_none)
Example #6
class ExtRewardCoeffSchedule:
    """Mixin for TFPolicy that adds an extrinsic reward coefficient schedule."""
    @DeveloperAPI
    def __init__(self, ext_reward_coeff, ext_reward_coeff_schedule):
        self.ext_reward_coeff = tf.get_variable(
            "ext_reward_coeff",
            initializer=float(ext_reward_coeff),
            trainable=False)

        if ext_reward_coeff_schedule is None:
            self.ext_reward_coeff_schedule = ConstantSchedule(ext_reward_coeff,
                                                              framework=None)
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(ext_reward_coeff_schedule, list):
                self.ext_reward_coeff_schedule = PiecewiseSchedule(
                    ext_reward_coeff_schedule,
                    outside_value=ext_reward_coeff_schedule[-1][-1],
                    framework=None)
            else:
                # Implements previous version but enforces outside_value
                self.ext_reward_coeff_schedule = PiecewiseSchedule(
                    [[0, ext_reward_coeff], [ext_reward_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(ExtRewardCoeffSchedule, self).on_global_var_update(global_vars)
        self.ext_reward_coeff.load(self.ext_reward_coeff_schedule.value(
            global_vars["timestep"]),
                                   session=self._sess)
Example #7
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""
    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = lr
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
            self.cur_lr = self._lr_schedule.value(0)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super().on_global_var_update(global_vars)
        if self._lr_schedule:
            self.cur_lr = self._lr_schedule.value(global_vars["timestep"])
            for opt in self._optimizers:
                for p in opt.param_groups:
                    p["lr"] = self.cur_lr
Example #8
class LearningRateSchedule(object):
    """Mixin for TFPolicyGraph that adds a learning rate schedule."""
    def __init__(self, lr, lr_schedule):
        self.cur_lr = tf.get_variable("lr", initializer=lr)
        if lr_schedule is None:
            self.lr_schedule = ConstantSchedule(lr)
        else:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1])

    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        self.cur_lr.load(self.lr_schedule.value(global_vars["timestep"]),
                         session=self._sess)

    def optimizer(self):
        return tf.train.AdamOptimizer(self.cur_lr)
Example #9
class ManualLearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""
    def __init__(self, lr, lr_schedule):
        self.cur_lr = lr
        if lr_schedule is None:
            self.lr_schedule = ConstantSchedule(lr, framework=None)
        else:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)

    # Not called automatically by any RLlib logic; call it from your training
    # script or a trainer callback (see the sketch after this class).
    def update_lr(self, timesteps_total):
        print(f"cur lr {self.cur_lr}")
        self.cur_lr = self.lr_schedule.value(timesteps_total)
        for opt in self._optimizers:
            for p in opt.param_groups:
                p["lr"] = self.cur_lr
Example #10
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""
    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False)
        if lr_schedule is None:
            self.lr_schedule = ConstantSchedule(lr, framework=None)
        else:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        self.cur_lr.load(self.lr_schedule.value(global_vars["timestep"]),
                         session=self._sess)

    @override(TFPolicy)
    def optimizer(self):
        return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)
Example #11
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""
    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        self.cur_lr = lr
        if lr_schedule is None:
            self.lr_schedule = ConstantSchedule(lr, framework=None)
        else:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        self.cur_lr = self.lr_schedule.value(global_vars["timestep"])

    @override(TorchPolicy)
    def optimizer(self):
        for p in self._optimizer.param_groups:
            p["lr"] = self.cur_lr
        return self._optimizer
Example #12
class LearningRateSchedule(object):
    """Mixin for TFPolicyGraph that adds a learning rate schedule."""

    def __init__(self, lr, lr_schedule):
        self.cur_lr = tf.get_variable("lr", initializer=lr)
        if lr_schedule is None:
            self.lr_schedule = ConstantSchedule(lr)
        else:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1])

    @override(PolicyGraph)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        self.cur_lr.load(
            self.lr_schedule.value(global_vars["timestep"]),
            session=self._sess)

    @override(TFPolicyGraph)
    def optimizer(self):
        return tf.train.AdamOptimizer(self.cur_lr)
Example #13
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""

    def __init__(self, lr, lr_schedule):
        self.cur_lr = tf.Variable(lr, name="lr", trainable=False)
        # self.cur_lr = tf.get_variable("lr", initializer=lr, trainable=False)
        if lr_schedule is None:
            self.lr_schedule = ConstantSchedule(lr, framework=None)
        else:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule,
                interpolation=_left_constant_interpolation,
                outside_value=lr_schedule[-1][-1],
                framework=None,
            )

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        self.cur_lr.load(
            self.lr_schedule.value(global_vars["timestep"]), session=self._sess
        )
Example #14
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(
        self,
        workers,
        learning_starts=1000,
        buffer_size=10000,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=1e-6,
        final_prioritized_replay_beta=0.4,
        train_batch_size=32,
        before_learn_on_batch=None,
        synchronize_sampling=False,
        prioritized_replay_beta_annealing_timesteps=100000 * 0.2,
    ):
        """Initialize an sync replay optimizer.

        Args:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            final_prioritized_replay_beta (float): Final value of beta.
            train_batch_size (int): size of batches to learn on
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
            prioritized_replay_beta_annealing_timesteps (int): The timestep at
                which PR-beta annealing should end.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts

        # Linearly annealing beta used in Rainbow paper, stopping at
        # `final_prioritized_replay_beta`.
        self.prioritized_replay_beta = PiecewiseSchedule(
            endpoints=[(0, prioritized_replay_beta),
                       (prioritized_replay_beta_annealing_timesteps,
                        final_prioritized_replay_beta)],
            outside_value=final_prioritized_replay_beta,
            framework=None)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # TODO(sven): This is currently structured differently for
                    #  torch/tf. Clean up these results/info dicts across
                    #  policies (note: fixing this in torch_policy.py will
                    #  break e.g. DDPPO!).
                    td_error = info.get("td_error",
                                        info["learner_stats"].get("td_error"))
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
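
The prioritized_replay_beta schedule built in __init__ anneals linearly from prioritized_replay_beta at timestep 0 to final_prioritized_replay_beta at the annealing cutoff and stays constant afterwards. A standalone sketch (a final beta of 1.0 and a 20k-step cutoff are assumed for illustration; with the defaults above, start and final beta are both 0.4, so the schedule is effectively flat):

from ray.rllib.utils.schedules import PiecewiseSchedule

beta = PiecewiseSchedule(
    endpoints=[(0, 0.4), (20_000, 1.0)],
    outside_value=1.0,
    framework=None)

for t in (0, 10_000, 20_000, 100_000):
    print(t, beta.value(t))  # 0.4, 0.7, 1.0, 1.0 with linear interpolation
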
Example #15
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(
        self,
        workers,
        learning_starts=1000,
        buffer_size=10000,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=1e-6,
        final_prioritized_replay_beta=0.4,
        train_batch_size=32,
        before_learn_on_batch=None,
        synchronize_sampling=False,
        prioritized_replay_beta_annealing_timesteps=100000 * 0.2,
    ):
        """Initialize an sync replay optimizer.

        Args:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            final_prioritized_replay_beta (float): Final value of beta.
            train_batch_size (int): size of batches to learn on
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
            prioritized_replay_beta_annealing_timesteps (int): The timestep at
                which PR-beta annealing should end.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts

        # Linearly annealing beta used in Rainbow paper, stopping at
        # `final_prioritized_replay_beta`.
        self.prioritized_replay_beta = PiecewiseSchedule(
            endpoints=[(0, prioritized_replay_beta),
                       (prioritized_replay_beta_annealing_timesteps,
                        final_prioritized_replay_beta)],
            outside_value=final_prioritized_replay_beta,
            framework=None)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)

    def save(self):
        with open(
                "/home/yunke/prl_proj/panda_ws/src/franka_cal_sim/python/replay_buffer.txt",
                "w") as f:
            for policy_id, replay_buffer in self.replay_buffers.items():
                for data in replay_buffer._storage:
                    obs_t, action, reward, obs_tp1, done, weight = data
                    obs_s = ','.join([str(v) for v in obs_t])
                    action = ','.join([str(v) for v in action])
                    obs_tp1 = ','.join([str(v) for v in obs_tp1])
                    f.write("%s\t%s\t%s\t%s\t%s\t%s\n" %
                            (obs_s, action, reward, obs_tp1, done, weight))

    def restore(self):
        obs, actions, rewards, next_obs, terminals, weights = [], [], [], [], [], []
        with open(
                "/home/yunke/prl_proj/panda_ws/src/franka_cal_sim/python/replay_buffer.txt",
                "r") as f:
            for line in f:
                cols = line.strip().split('\t')
                obs.append(np.array([float(v) for v in cols[0].split(',')]))
                actions.append(np.array([float(v) for v in cols[1].split(',')]))
                rewards.append(float(cols[2]))
                next_obs.append(np.array([float(v) for v in cols[3].split(',')]))
                # Note: bool("False") is True, so compare against the string
                # written out by save().
                terminals.append(cols[4] == "True")
                weights.append(float(cols[5]))

        batch = SampleBatch({
            "obs": obs,
            "actions": actions,
            "rewards": rewards,
            "new_obs": next_obs,
            "dones": terminals,
            "weights": weights
        })

        # Re-insert the restored transitions into the default policy's
        # replay buffer, mirroring the add() calls in step().
        for row in batch.rows():
            self.replay_buffers[DEFAULT_POLICY_ID].add(
                pack_if_needed(row["obs"]),
                row["actions"],
                row["rewards"],
                pack_if_needed(row["new_obs"]),
                row["dones"],
                weight=None)
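
A hedged usage sketch of the save()/restore() hooks above (optimizer is assumed to be an instance of this class, num_iters is supplied by the caller, and restore() expects the hard-coded dump file from a previous run to exist):

optimizer.restore()          # pre-fill the buffer from a previous run's dump
for _ in range(num_iters):
    optimizer.step()
optimizer.save()             # persist buffer contents for the next run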