    def _init(self, num_sgd_iter=1, train_batch_size=1):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.train_batch_size = train_batch_size
        self.learner_stats = {}
Example #2
    def _init(
            self, learning_starts=1000, buffer_size=10000,
            prioritized_replay=True, prioritized_replay_alpha=0.6,
            prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6,
            train_batch_size=32, sample_batch_size=4, clip_rewards=True):

        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()

        # Set up replay buffer
        if prioritized_replay:
            self.replay_buffer = PrioritizedReplayBuffer(
                buffer_size, alpha=prioritized_replay_alpha,
                clip_rewards=clip_rewards)
        else:
            self.replay_buffer = ReplayBuffer(buffer_size, clip_rewards)

        assert buffer_size >= self.replay_starts
Example #3
class LocalSyncOptimizer(Optimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """

    def _init(self, batch_size=32):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.batch_size = batch_size

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                samples = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                samples = self.local_evaluator.sample()

        with self.grad_timer:
            grad = self.local_evaluator.compute_gradients(samples)
            self.local_evaluator.apply_gradients(grad)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count

    def stats(self):
        return dict(Optimizer.stats(self), **{
            "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
            "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
            "update_time_ms": round(1000 * self.update_weights_timer.mean, 3),
            "opt_peak_throughput": round(self.grad_timer.mean_throughput, 3),
            "opt_samples": round(self.grad_timer.mean_units_processed, 3),
        })
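Below is a self-contained toy sketch of the synchronous pattern that LocalSyncOptimizer.step() implements: broadcast the local weights, gather and concatenate samples from all evaluators, then update the local model. ToyEvaluator and its methods are illustrative stand-ins written for this note, not RLlib APIs:

import numpy as np

class ToyEvaluator:
    """Toy stand-in for an RLlib evaluator (illustrative only)."""

    def __init__(self):
        self.weights = np.zeros(4)

    def get_weights(self):
        return self.weights

    def set_weights(self, w):
        self.weights = w.copy()

    def sample(self):
        # A fake "SampleBatch": 8 rows of 4 features.
        return np.random.randn(8, 4)

    def compute_gradients(self, samples):
        return samples.mean(axis=0)  # fake gradient

    def apply_gradients(self, grad):
        self.weights -= 0.01 * grad

local = ToyEvaluator()
remotes = [ToyEvaluator() for _ in range(2)]

# One synchronous step, mirroring LocalSyncOptimizer.step():
for e in remotes:                                        # 1. broadcast weights
    e.set_weights(local.get_weights())
samples = np.concatenate([e.sample() for e in remotes])  # 2. sample + concat
local.apply_gradients(local.compute_gradients(samples))  # 3. local update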
Example #4
    def __init__(self, workers, train_batch_size=10000, microbatch_size=1000):
        PolicyOptimizer.__init__(self, workers)

        if train_batch_size <= microbatch_size:
            raise ValueError(
                "The microbatch size must be smaller than the train batch "
                "size, got {} vs {}".format(microbatch_size, train_batch_size))

        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.train_batch_size = train_batch_size
        self.microbatch_size = microbatch_size
        self.learner_stats = {}
        self.policies = dict(
            self.workers.local_worker().foreach_trainable_policy(lambda p, i:
                                                                 (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))
    def __init__(self, num_shards, learning_starts, buffer_size,
                 train_batch_size):
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.train_batch_size = train_batch_size

        # self.prioritized_replay_beta = prioritized_replay_beta
        # self.prioritized_replay_eps = prioritized_replay_eps

        def new_buffer():
            return ReplayBuffer(self.buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        # self.update_priorities_timer = TimerStat()
        self.num_added = 0
    def __init__(self, num_shards, learning_starts, buffer_size,
                 train_batch_size, prioritized_replay_alpha,
                 prioritized_replay_beta, prioritized_replay_eps,
                 clip_rewards):
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.train_batch_size = train_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps

        self.replay_buffer = PrioritizedReplayBuffer(
            self.buffer_size,
            alpha=prioritized_replay_alpha,
            clip_rewards=clip_rewards)

        # Metrics
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
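Both sharded ReplayActor variants above split the global learning_starts and buffer_size budgets evenly across shards, so the cluster-wide totals are unchanged by num_shards. A quick worked example of that arithmetic (numbers chosen for illustration):

num_shards, learning_starts, buffer_size = 4, 50000, 1000000
replay_starts_per_shard = learning_starts // num_shards  # 12500
buffer_size_per_shard = buffer_size // num_shards        # 250000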
Example #7
    def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter,
                 learner_queue_size, learner_queue_timeout):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(inqueue=self.inqueue,
                                                size=minibatch_buffer_size,
                                                timeout=learner_queue_timeout,
                                                num_passes=num_sgd_iter)
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stats = {}
        self.stopped = False
Example #8
    def __init__(
        self,
        local_worker: RolloutWorker,
        minibatch_buffer_size: int,
        num_sgd_iter: int,
        learner_queue_size: int,
        learner_queue_timeout: int,
    ):
        """Initialize the learner thread.

        Args:
            local_worker: process local rollout worker holding
                policies this thread will call learn_on_batch() on
            minibatch_buffer_size: max number of train batches to store
                in the minibatching buffer
            num_sgd_iter: number of passes to learn on per train batch
            learner_queue_size: max size of queue of inbound
                train batches to this thread
            learner_queue_timeout: raise an exception if the queue has
                been empty for this many seconds
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter,
        )
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.load_timer = TimerStat()
        self.load_wait_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.learner_info = {}
        self.stopped = False
        self.num_steps = 0
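The inqueue/outqueue pair in both LearnerThread variants is a standard bounded producer-consumer handoff: rollout code enqueues train batches, the thread learns on them, and results flow back to the driver. A self-contained toy version of that handoff (the None sentinel used for shutdown is an assumption made here, not how RLlib stops the thread):

import queue
import threading

inqueue = queue.Queue(maxsize=16)
outqueue = queue.Queue()

def learner_loop():
    while True:
        batch = inqueue.get()
        if batch is None:  # sentinel to stop (illustrative)
            break
        outqueue.put(("trained", len(batch)))

t = threading.Thread(target=learner_loop, daemon=True)
t.start()
inqueue.put([1, 2, 3])   # a rollout thread enqueues a train batch
print(outqueue.get())    # the driver collects learner results
inqueue.put(None)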
Example #9
    def _init(
            self, learning_starts=1000, buffer_size=10000,
            prioritized_replay=True, prioritized_replay_alpha=0.6,
            prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6,
            train_batch_size=512, sample_batch_size=50,
            num_replay_buffer_shards=1, max_weight_sync_delay=400,
            clip_rewards=True, debug=False):

        self.debug = debug
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.sample_batch_size = sample_batch_size
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.local_evaluator)
        self.learner.start()

        # TODO(ekl) use create_colocated() for these actors once
        # https://github.com/ray-project/ray/issues/1734 is fixed
        self.replay_actors = [
            ReplayActor.remote(
                num_replay_buffer_shards, learning_starts, buffer_size,
                train_batch_size, prioritized_replay_alpha,
                prioritized_replay_beta, prioritized_replay_eps, clip_rewards)
            for _ in range(num_replay_buffer_shards)
        ]
        assert len(self.remote_evaluators) > 0

        # Stats
        self.timers = {k: TimerStat() for k in [
            "put_weights", "get_samples", "enqueue", "sample_processing",
            "replay_processing", "update_priorities", "train", "sample"]}
        self.num_weight_syncs = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            self.steps_since_update[ev] = 0
            for _ in range(SAMPLE_QUEUE_DEPTH):
                self.sample_tasks.add(ev, ev.sample.remote())
    def __init__(self,
                 workers,
                 num_sgd_iter=1,
                 train_batch_size=1,
                 sgd_minibatch_size=0,
                 standardize_fields=frozenset([])):
        PolicyOptimizer.__init__(self, workers)

        self.update_weights_timer = TimerStat()
        self.standardize_fields = standardize_fields
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.train_batch_size = train_batch_size
        self.learner_stats = {}
        self.policies = dict(self.workers.local_worker()
                             .foreach_trainable_policy(lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=512,
              sample_batch_size=50,
              num_replay_buffer_shards=1,
              max_weight_sync_delay=400,
              clip_rewards=True,
              debug=False):

        self.debug = debug
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.local_evaluator)
        self.learner.start()

        self.replay_actors = create_colocated(ReplayActor, [
            num_replay_buffer_shards, learning_starts, buffer_size,
            train_batch_size, prioritized_replay_alpha,
            prioritized_replay_beta, prioritized_replay_eps, clip_rewards
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "enqueue", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.remote_evaluators:
            self.set_evaluators(self.remote_evaluators)
    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 train_batch_size=32):
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}
    def __init__(self,
                 num_shards,
                 learning_starts,
                 buffer_size,
                 replay_batch_size,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 multiagent_sync_replay=False):
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.multiagent_sync_replay = multiagent_sync_replay

        def gen_replay():
            while True:
                yield self.replay()

        ParallelIteratorWorker.__init__(self, gen_replay, False)

        def new_buffer():
            return PrioritizedReplayBuffer(self.buffer_size,
                                           alpha=prioritized_replay_alpha)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # Make externally accessible for testing.
        global _local_replay_buffer
        _local_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None
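Note that gen_replay above is an infinite generator: ParallelIteratorWorker pulls one replay batch per next() call, so sampling happens lazily, on demand. A toy version of the same pattern (names and buffer contents are illustrative):

import itertools
import random

def gen_replay(buffer, batch_size):
    # Each next() yields one freshly sampled batch, forever.
    while True:
        yield random.sample(buffer, batch_size)

buffer = list(range(100))
for batch in itertools.islice(gen_replay(buffer, 4), 3):
    print(batch)  # the consumer pulls exactly as many batches as it needs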
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=32,
              sample_batch_size=4,
              clip_rewards=True):

        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(
                    buffer_size,
                    alpha=prioritized_replay_alpha,
                    clip_rewards=clip_rewards)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size, clip_rewards)

        self.replay_buffers = collections.defaultdict(new_buffer)

        assert buffer_size >= self.replay_starts
    def __init__(
            self,
            num_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
            # added for dynamic experience replay
            human_demonstration,
            multiple_human_data,
            human_data_dir,
            dynamic_experience_replay,
            demonstration_zone_percentage,
            robot_demo_path):
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.train_batch_size = train_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps

        # Metrics
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # added for dynamic experience replay
        self.prioritized_replay_alpha = prioritized_replay_alpha
        self.human_demonstration = human_demonstration
        self.multiple_human_data = multiple_human_data
        self.human_data_dir = human_data_dir
        self.dynamic_experience_replay = dynamic_experience_replay
        self.demonstration_zone_percentage = demonstration_zone_percentage
        self.robot_demo_path = robot_demo_path
        self.load_human_demo()
Example #16
    def _init(self, sgd_batch_size=128, sgd_stepsize=5e-5, num_sgd_iter=10):
        assert isinstance(self.local_evaluator, TFMultiGPUSupport)
        self.batch_size = sgd_batch_size
        self.sgd_stepsize = sgd_stepsize
        self.num_sgd_iter = num_sgd_iter
        gpu_ids = ray.get_gpu_ids()
        if not gpu_ids:
            self.devices = ["/cpu:0"]
        else:
            self.devices = ["/gpu:{}".format(i) for i in range(len(gpu_ids))]
        assert self.batch_size > len(self.devices), "batch size too small"
        self.per_device_batch_size = self.batch_size // len(self.devices)
        self.sample_timer = TimerStat()
        self.load_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.update_weights_timer = TimerStat()

        print("LocalMultiGPUOptimizer devices", self.devices)
        print("LocalMultiGPUOptimizer batch size", self.batch_size)

        # List of (feature name, feature placeholder) tuples
        self.loss_inputs = self.local_evaluator.tf_loss_inputs()

        # per-GPU graph copies created below must share vars with the policy
        tf.get_variable_scope().reuse_variables()

        self.par_opt = LocalSyncParallelOptimizer(
            tf.train.AdamOptimizer(self.sgd_stepsize),
            self.devices,
            [ph for _, ph in self.loss_inputs],
            self.per_device_batch_size,
            lambda *ph: self.local_evaluator.build_tf_loss(ph),
            os.getcwd())

        self.sess = self.local_evaluator.sess
        self.sess.run(tf.global_variables_initializer())
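The per-device batch split above is plain integer division, which is also why the assert requires batch_size > len(devices) (otherwise per_device_batch_size would be zero). For example:

batch_size = 128
devices = ["/gpu:0", "/gpu:1"]
assert batch_size > len(devices), "batch size too small"
per_device_batch_size = batch_size // len(devices)  # 64 samples per GPU tower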
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              train_batch_size=32):
        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              schedule_max_timesteps=100000,
              beta_annealing_fraction=0.2,
              final_prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=32,
              sample_batch_size=4):

        self.replay_starts = learning_starts
        # Linearly anneal beta, as in the Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(
                schedule_max_timesteps * beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(
                    buffer_size, alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        assert buffer_size >= self.replay_starts
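Several examples here build the beta schedule as LinearSchedule(schedule_timesteps=int(schedule_max_timesteps * beta_annealing_fraction), ...). A minimal sketch approximating what such a linear schedule computes, written for illustration rather than taken from the library (the 0.4 -> 1.0 range follows the Rainbow paper; the defaults above keep beta constant at 0.4):

class ToyLinearSchedule:
    def __init__(self, schedule_timesteps, initial_p, final_p):
        self.schedule_timesteps = schedule_timesteps
        self.initial_p = initial_p
        self.final_p = final_p

    def value(self, t):
        # Interpolate linearly, then hold final_p past the horizon.
        frac = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + frac * (self.final_p - self.initial_p)

beta = ToyLinearSchedule(int(100000 * 0.2), initial_p=0.4, final_p=1.0)
print(beta.value(0), beta.value(10000), beta.value(50000))  # 0.4 0.7 1.0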
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(
        self,
        workers,
        learning_starts=1000,
        buffer_size=10000,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=1e-6,
        final_prioritized_replay_beta=0.4,
        train_batch_size=32,
        before_learn_on_batch=None,
        synchronize_sampling=False,
        prioritized_replay_beta_annealing_timesteps=100000 * 0.2,
    ):
        """Initialize an sync replay optimizer.

        Args:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            final_prioritized_replay_beta (float): Final value of beta.
            train_batch_size (int): size of batches to learn on
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
            prioritized_replay_beta_annealing_timesteps (int): The timestep at
                which PR-beta annealing should end.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts

        # Linearly anneal beta as in the Rainbow paper, stopping at
        # `final_prioritized_replay_beta`.
        self.prioritized_replay_beta = PiecewiseSchedule(
            endpoints=[(0, prioritized_replay_beta),
                       (prioritized_replay_beta_annealing_timesteps,
                        final_prioritized_replay_beta)],
            outside_value=final_prioritized_replay_beta,
            framework=None)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # TODO(sven): This is currently structured differently for
                    #  torch/tf. Clean up these results/info dicts across
                    #  policies (note: fixing this in torch_policy.py will
                    #  break e.g. DDPPO!).
                    td_error = info.get("td_error",
                                        info["learner_stats"].get("td_error"))
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
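The priority update inside _optimize() is simply new_priority = |td_error| + prioritized_replay_eps; the small epsilon keeps every transition sampleable even when its TD error is exactly zero. A standalone illustration:

import numpy as np

prioritized_replay_eps = 1e-6
td_error = np.array([0.5, -2.0, 0.0])
new_priorities = np.abs(td_error) + prioritized_replay_eps
print(new_priorities)  # every entry stays strictly positive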
    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 train_batch_size=32,
                 sample_batch_size=4,
                 before_learn_on_batch=None,
                 synchronize_sampling=False):
        """Initialize an sync replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            schedule_max_timesteps (int): number of timesteps in the schedule
            beta_annealing_fraction (float): fraction of schedule to anneal
                beta over
            final_prioritized_replay_beta (float): final value of beta
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        # Linearly anneal beta, as in the Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(schedule_max_timesteps *
                                   beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))
Example #21
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(self,
                 local_evaluator,
                 remote_evaluators,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=32,
                 sample_batch_size=4):
        PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)

        self.replay_starts = learning_starts
        # Linearly anneal beta, as in the Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(schedule_max_timesteps *
                                   beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray_get_and_free(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample(
                         self.train_batch_size,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample(self.train_batch_size)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
    def __init__(self,
                 workers,
                 config,
                 learning_starts=1000,
                 buffer_size=50000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 train_batch_size=32,
                 sample_batch_size=4,
                 before_learn_on_batch=None,
                 synchronize_sampling=False):
        """Initialize an sync replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            schedule_max_timesteps (int): number of timesteps in the schedule
            beta_annealing_fraction (float): fraction of schedule to anneal
                beta over
            final_prioritized_replay_beta (float): final value of beta
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        # Linearly anneal beta, as in the Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(schedule_max_timesteps *
                                   beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}
        '''Attention Info'''
        self.traffic_light_node_dict = {}
        self.record_dir = '/home/skylark/PycharmRemote/Gamma-Reward-Perfect/record/' + config[
            "env_config"]["Name"]
        self.read_traffic_light_node_dict()
        self.tmp_dic = self.traffic_light_node_dict['intersection_1_1'][
            'inter_id_to_index']
        # -------------------------------------------
        '''
        For comparing reward changes
        '''
        self.raw_reward_store = {}
        self.Reward_store = {}
        for inter_id in self.tmp_dic:
            self.raw_reward_store[inter_id] = []
            self.Reward_store[inter_id] = []
        # self.j_store = 0
        # ------------------------------
        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))
        '''
        For Gamma Reward by Skylark
        '''
        self.memory_thres = config["env_config"]["memory_thres"]
        self.num_steps_presampled = 0
        self.gamma = 0.5
        self.index = 0
        self.punish_coeff = 1.5
        self.config = config
        # Set up replay buffer
        if prioritized_replay:

            def pre_new_buffer():
                return PrioritizedReplayBuffer(buffer_size + self.memory_thres,
                                               alpha=prioritized_replay_alpha)
        else:

            def pre_new_buffer():
                return ReplayBuffer(buffer_size + self.memory_thres)

        self.pre_replay_buffers = collections.defaultdict(pre_new_buffer)
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(self,
                 workers,
                 config,
                 learning_starts=1000,
                 buffer_size=50000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 train_batch_size=32,
                 sample_batch_size=4,
                 before_learn_on_batch=None,
                 synchronize_sampling=False):
        """Initialize an sync replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            schedule_max_timesteps (int): number of timesteps in the schedule
            beta_annealing_fraction (float): fraction of schedule to anneal
                beta over
            final_prioritized_replay_beta (float): final value of beta
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        # Linearly anneal beta, as in the Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(schedule_max_timesteps *
                                   beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}
        '''Attention Info'''
        self.traffic_light_node_dict = {}
        self.record_dir = '/home/skylark/PycharmRemote/Gamma-Reward-Perfect/record/' + config[
            "env_config"]["Name"]
        self.read_traffic_light_node_dict()
        self.tmp_dic = self.traffic_light_node_dict['intersection_1_1'][
            'inter_id_to_index']
        # -------------------------------------------
        '''
        For comparing reward changes
        '''
        self.raw_reward_store = {}
        self.Reward_store = {}
        for inter_id in self.tmp_dic:
            self.raw_reward_store[inter_id] = []
            self.Reward_store[inter_id] = []
        # self.j_store = 0
        # ------------------------------
        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))
        '''
        For Gamma Reward by Skylark
        '''
        self.memory_thres = config["env_config"]["memory_thres"]
        self.num_steps_presampled = 0
        self.gamma = 0.5
        self.index = 0
        self.punish_coeff = 1.5
        self.config = config
        # Set up replay buffer
        if prioritized_replay:

            def pre_new_buffer():
                return PrioritizedReplayBuffer(buffer_size + self.memory_thres,
                                               alpha=prioritized_replay_alpha)
        else:

            def pre_new_buffer():
                return ReplayBuffer(buffer_size + self.memory_thres)

        self.pre_replay_buffers = collections.defaultdict(pre_new_buffer)
        # ------------------------------------------

        # '''
        # For Attention Reward by Skylark
        # '''
        # sa_size = [(15, 8), (15, 8), (15, 8), (15, 8), (15, 8), (15, 8)]
        # critic_hidden_dim = 128
        # attend_heads = 4
        # q_lr = 0.01
        # self.attention = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
        #                                  attend_heads=attend_heads)
        # self.target_attention = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
        #                                         attend_heads=attend_heads)
        # hard_update(self.target_attention, self.attention)
        # self.attention_optimizer = Adam(self.attention.parameters(), lr=q_lr,
        #                                 weight_decay=1e-3)
        # self.niter = 0
        # ------------------------------------------------------------------

    @override(PolicyOptimizer)
    def step(self, attention_score_dic=None):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)
            '''
            For Gamma Reward by LJJ (see the local history for the changes)
            '''
            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.pre_replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

            if self.num_steps_presampled >= self.memory_thres:
                self._preprocess(batch, attention_score_dic)

            self.num_steps_presampled += batch.count

        # -----------------------------------------------------------------------

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)

    def _preprocess(self, batch, attention_score_dic=None):
        """
        Self-defined function: For Gamma Reward Replay Buffer Amendment
        :param batch: SampleBatch class,
        :param attention_score_dic: For transferring Attention score calculated by target attention layers
        :return: return Amendatory Replay Buffer
        """
        global j_store
        for policy_id, s in batch.policy_batches.items():
            storage = list(self.pre_replay_buffers[policy_id]._storage)
            index = len(storage) - self.memory_thres - 1
            tmp_buffer = storage.copy()
            current_intersection = self.inter_num_2_id(
                policy_id_handle(policy_id))
            '''
            For comparing the change of rewards 
            '''
            # ------------------------------
            while index > self.index - 1:
                obs = storage[index][0]
                action = storage[index][1]
                reward = storage[index][2]
                new_obs = storage[index][3]
                done = storage[index][4]
                p_value = 0

                all_roads_path_2dlst = np.array(
                    self.config['env_config']['lane_phase_info']
                    [current_intersection]['phase_roadLink_mapping'][action +
                                                                     1])
                all_end_roads = self.config['env_config']['lane_phase_info'][
                    current_intersection]['end_lane']
                permitted_end_roads = np.unique([
                    all_roads_path_2dlst[lane_index, 1] for lane_index,
                    start_lane in enumerate(all_roads_path_2dlst[:, 0])
                    if start_lane[-1] != '2'
                ])
                dis_permitted_end_roads = list(
                    set(all_end_roads).difference(
                        set(list(permitted_end_roads))))

                # Take neighbors into account
                for other_policy_id, s in batch.policy_batches.items():
                    other_intersection = self.inter_num_2_id(
                        policy_id_handle(other_policy_id))
                    if other_policy_id != policy_id and other_intersection in \
                            self.traffic_light_node_dict[current_intersection]['neighbor_ENWS']:
                        other_storage = self.pre_replay_buffers[
                            other_policy_id]._storage
                        '''
                        For the corresponding lane in a neighbouring
                        intersection, m_2 is the waiting count at time
                        step t+n and m_1 at step t:
                        differential = (m_2 - m_1) / m_1
                        '''
                        road_index_dict = {
                            road: road_index
                            for road_index, road in
                            enumerate(self.config['env_config']['road_sort']
                                      [other_intersection])
                        }

                        # differential = np.max(
                        #     np.array(other_storage[index + self.memory_thres - 1:index + self.memory_thres])[:,
                        #     2]) / other_storage[index][2]
                        for road in road_index_dict.keys():
                            if road in all_end_roads:
                                if road in permitted_end_roads:
                                    I_a = -1
                                elif road in dis_permitted_end_roads:
                                    I_a = 0
                                else:
                                    # Unreachable if permitted and
                                    # dis-permitted lanes partition
                                    # all_end_roads; fail loudly if not,
                                    # since I_a would otherwise be unbound.
                                    raise ValueError(
                                        'road {} is neither permitted '
                                        'nor dis-permitted'.format(road))
                                road_index = road_index_dict[road]

                                m_1 = np.array(
                                    other_storage[index])[0][road_index]
                                m_2 = np.mean([
                                    other_storage[index + self.memory_thres -
                                                  2][0][road_index],
                                    other_storage[index + self.memory_thres -
                                                  1][0][road_index]
                                ])
                                if m_1 == 0 or m_2 == m_1:
                                    differential = 0
                                else:
                                    # Relative change in the waiting count.
                                    # The parentheses matter: without them,
                                    # m_2 - m_1 / m_1 computes m_2 - 1.
                                    differential = (m_2 - m_1) / m_1
                                    # m_2 = 0, m_1 != 0 -> differential = -1
                                    if differential > 1:
                                        differential = 0

                                p_value += m_1 * np.tanh(differential) * I_a
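                                # Worked example (illustrative): a permitted
                                # downstream lane (I_a = -1) with m_1 = 4
                                # waiting vehicles at step t and m_2 = 6 at
                                # step t+n gives differential = 0.5 and a
                                # penalty of 4 * tanh(0.5) * (-1) ~= -1.85.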

                if self.config['env_config']['Gamma_Reward']:
                    p_reward = reward + self.gamma * p_value
                    # print('Reward: ' + str(Reward) + ',' + 'reward: ' + str(reward))
                    if p_reward <= -20:
                        p_reward = -20
                    # print(Reward)
                else:
                    p_reward = reward
                '''
                For comparing the reward change
                '''
                # if 50 < j_store < 100:
                #     self.raw_reward_store[self.inter_num_2_id(policy_id_handle(policy_id))].append(reward)
                #     self.Reward_store[self.inter_num_2_id(policy_id_handle(policy_id))].append(Reward)

                # ------------------------------
                tmp_buffer[index] = list(storage[index])
                tmp_buffer[index][2] = p_reward
                index -= 1

            for i in range(self.index, len(tmp_buffer) - self.memory_thres):
                self.replay_buffers[policy_id].add(obs_t=tmp_buffer[i][0],
                                                   action=tmp_buffer[i][1],
                                                   reward=tmp_buffer[i][2],
                                                   obs_tp1=tmp_buffer[i][3],
                                                   done=tmp_buffer[i][4],
                                                   weight=None)
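            # Note (illustrative reading of the code above): only transitions
            # at least memory_thres steps old are committed here, so every
            # reward can first be amended with the neighbours' states
            # observed memory_thres steps later.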

        # Reward MDP
        index = len(storage) - self.memory_thres - 1
        while index > self.index - 1:
            for policy_id, s in batch.policy_batches.items():
                current_intersection = self.inter_num_2_id(
                    policy_id_handle(policy_id))
                storage = list(self.replay_buffers[policy_id]._storage)
                p_reward = storage[index][2]
                sum_other_reward = 0
                for other_policy_id, s in batch.policy_batches.items():
                    other_intersection = self.inter_num_2_id(
                        policy_id_handle(other_policy_id))
                    if other_policy_id != policy_id and other_intersection in \
                            self.traffic_light_node_dict[current_intersection]['neighbor_ENWS']:
                        other_storage = self.replay_buffers[
                            other_policy_id]._storage
                        pre_other_storage = self.pre_replay_buffers[
                            other_policy_id]._storage
                        if index + self.memory_thres >= len(other_storage):
                            sum_other_reward = 0
                        else:
                            sum_other_reward += np.tanh(
                                other_storage[index + self.memory_thres][2] /
                                pre_other_storage[index +
                                                  self.memory_thres][2] -
                                self.punish_coeff)
                Reward = p_reward + self.gamma * sum_other_reward
                self.replay_buffers[policy_id]._storage[index] = list(
                    self.replay_buffers[policy_id]._storage[index])
                self.replay_buffers[policy_id]._storage[index][2] = Reward
                self.replay_buffers[policy_id]._storage[index] = tuple(
                    self.replay_buffers[policy_id]._storage[index])
            index -= 1

        j_store += 1
        self.index = len(storage) - self.memory_thres

        # if j_store == 100:
        #     print("Start recording the reward !!!!!!!!!!!")
        #     raw_reward_store_np = {}
        #     Reward_store_np = {}
        #     for inter_id in self.tmp_dic:
        #         raw_reward_store_np[inter_id] = np.array(self.raw_reward_store[inter_id])
        #         Reward_store_np[inter_id] = np.array(self.Reward_store[inter_id])
        #     raw_reward_store_pd = pd.DataFrame(dict((k, pd.Series(v)) for k, v in raw_reward_store_np.items()))
        #     Reward_store_pd = pd.DataFrame(dict((k, pd.Series(v)) for k, v in Reward_store_np.items()))
        #     raw_reward_store_pd.to_csv(os.path.join(self.record_dir, 'raw_reward_store_pd.csv'))
        #     Reward_store_pd.to_csv(os.path.join(self.record_dir, 'Reward_store_pd.csv'))

        # self.replay_buffers = storage[:self.index]

        self.num_steps_sampled = len(
            self.replay_buffers[policy_id]._storage)  # Any policy_id is OK

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

    def _sigmoid(self, x):
        return 1 / (1 + math.exp(-x))

    def read_traffic_light_node_dict(self):
        path_to_read = os.path.join(self.record_dir,
                                    'traffic_light_node_dict.conf')
        with open(path_to_read, 'r') as f:
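            # NOTE (suggestion, not in the original snippet): eval() executes
            # arbitrary expressions from the file; ast.literal_eval would be
            # a safer way to parse a plain dict literal.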
            self.traffic_light_node_dict = eval(f.read())
            print("Read traffic_light_node_dict")

    def inter_num_2_id(self, num):
        return list(self.tmp_dic.keys())[list(
            self.tmp_dic.values()).index(num)]

class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""

    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              schedule_max_timesteps=100000,
              beta_annealing_fraction=0.2,
              final_prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=32,
              sample_batch_size=4):

        self.replay_starts = learning_starts
        # Linearly anneal beta, as in the Rainbow paper.
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(
                schedule_max_timesteps * beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(
                    buffer_size, alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        assert buffer_size >= self.replay_starts
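        # Illustrative sketch (assumption, uses a non-default final beta):
        # with schedule_max_timesteps=100000, beta_annealing_fraction=0.2 and
        # final_prioritized_replay_beta=1.0, the schedule spans 20000 steps:
        #   self.prioritized_replay_beta.value(0)      -> 0.4
        #   self.prioritized_replay_beta.value(10000)  -> 0.7
        #   self.prioritized_replay_beta.value(50000)  -> 1.0 (clamped)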

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (
                        np.abs(td_error) + self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample(
                         self.train_batch_size,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample(self.train_batch_size)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)

class SyncBatchReplayOptimizer(PolicyOptimizer):
    """Variant of the sync replay optimizer that replays entire batches.

    This enables RNN support. Does not currently support prioritization."""

    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              train_batch_size=32):
        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray.get(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({
                        DEFAULT_POLICY_ID: batch
                    }, batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            return self._optimize()
        else:
            return {}

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
        return info_dict
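    # Note (illustrative): whole stored batches are drawn with replacement
    # until train_batch_size steps are gathered, e.g. stored batches of
    # count 200 and train_batch_size=500 yield three batches (600 steps).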
Exemple #26
0
class LocalSyncReplayOptimizer(Optimizer):
    """Variant of the local sync optimizer that supports replay (for DQN)."""

    def _init(
            self, learning_starts=1000, buffer_size=10000,
            prioritized_replay=True, prioritized_replay_alpha=0.6,
            prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6,
            train_batch_size=32, sample_batch_size=4, clip_rewards=True):

        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()

        # Set up replay buffer
        if prioritized_replay:
            self.replay_buffer = PrioritizedReplayBuffer(
                buffer_size, alpha=prioritized_replay_alpha,
                clip_rewards=clip_rewards)
        else:
            self.replay_buffer = ReplayBuffer(buffer_size, clip_rewards)

        assert buffer_size >= self.replay_starts

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()
            for row in batch.rows():
                self.replay_buffer.add(
                    row["obs"], row["actions"], row["rewards"], row["new_obs"],
                    row["dones"], row["weights"])

        if len(self.replay_buffer) >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    def _optimize(self):
        with self.replay_timer:
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1,
                    dones, weights, batch_indexes) = self.replay_buffer.sample(
                        self.train_batch_size,
                        beta=self.prioritized_replay_beta)
            else:
                (obses_t, actions, rewards, obses_tp1,
                    dones) = self.replay_buffer.sample(
                        self.train_batch_size)
                weights = np.ones_like(rewards)
                batch_indexes = - np.ones_like(rewards)

            samples = SampleBatch({
                "obs": obses_t, "actions": actions, "rewards": rewards,
                "new_obs": obses_tp1, "dones": dones, "weights": weights,
                "batch_indexes": batch_indexes})

        with self.grad_timer:
            td_error = self.local_evaluator.compute_apply(samples)
            new_priorities = (
                np.abs(td_error) + self.prioritized_replay_eps)
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                self.replay_buffer.update_priorities(
                    samples["batch_indexes"], new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def stats(self):
        return dict(Optimizer.stats(self), **{
            "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
            "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
            "update_time_ms": round(1000 * self.update_weights_timer.mean, 3),
            "opt_peak_throughput": round(self.grad_timer.mean_throughput, 3),
            "opt_samples": round(self.grad_timer.mean_units_processed, 3),
        })

    def __init__(
        self,
        workers,
        learning_starts=1000,
        buffer_size=10000,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=1e-6,
        final_prioritized_replay_beta=0.4,
        train_batch_size=32,
        before_learn_on_batch=None,
        synchronize_sampling=False,
        prioritized_replay_beta_annealing_timesteps=100000 * 0.2,
    ):
        """Initialize an sync replay optimizer.

        Args:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            final_prioritized_replay_beta (float): Final value of beta.
            train_batch_size (int): size of batches to learn on
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
            prioritized_replay_beta_annealing_timesteps (int): The timestep at
                which PR-beta annealing should end.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts

        # Linearly annealing beta used in Rainbow paper, stopping at
        # `final_prioritized_replay_beta`.
        self.prioritized_replay_beta = PiecewiseSchedule(
            endpoints=[(0, prioritized_replay_beta),
                       (prioritized_replay_beta_annealing_timesteps,
                        final_prioritized_replay_beta)],
            outside_value=final_prioritized_replay_beta,
            framework=None)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))
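        # Illustrative sketch (assumption, uses a non-default final beta):
        # the PiecewiseSchedule above interpolates linearly between its
        # endpoints, e.g.
        #   beta = PiecewiseSchedule(
        #       endpoints=[(0, 0.4), (20000, 1.0)],
        #       outside_value=1.0, framework=None)
        #   beta.value(0)        # -> 0.4
        #   beta.value(10000)    # -> 0.7
        #   beta.value(30000)    # -> 1.0 (outside_value past last endpoint)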

    def _init(self,
              train_batch_size=500,
              sample_batch_size=50,
              num_envs_per_worker=1,
              num_gpus=0,
              lr=0.0005,
              replay_buffer_num_slots=0,
              replay_proportion=0.0,
              num_data_loader_buffers=1,
              max_sample_requests_in_flight_per_worker=2,
              broadcast_interval=1,
              num_sgd_iter=1,
              minibatch_buffer_size=1,
              learner_queue_size=16,
              _fake_gpus=False):
        self.train_batch_size = train_batch_size
        self.sample_batch_size = sample_batch_size
        self.broadcast_interval = broadcast_interval

        self._stats_start_time = time.time()
        self._last_stats_time = {}
        self._last_stats_sum = {}

        if num_gpus > 1 or num_data_loader_buffers > 1:
            logger.info(
                "Enabling multi-GPU mode, {} GPUs, {} parallel loaders".format(
                    num_gpus, num_data_loader_buffers))
            if num_data_loader_buffers < minibatch_buffer_size:
                raise ValueError(
                    "In multi-gpu mode you must have at least as many "
                    "parallel data loader buffers as minibatch buffers: "
                    "{} vs {}".format(num_data_loader_buffers,
                                      minibatch_buffer_size))
            self.learner = TFMultiGPULearner(
                self.local_evaluator,
                lr=lr,
                num_gpus=num_gpus,
                train_batch_size=train_batch_size,
                num_data_loader_buffers=num_data_loader_buffers,
                minibatch_buffer_size=minibatch_buffer_size,
                num_sgd_iter=num_sgd_iter,
                learner_queue_size=learner_queue_size,
                _fake_gpus=_fake_gpus)
        else:
            self.learner = LearnerThread(self.local_evaluator,
                                         minibatch_buffer_size, num_sgd_iter,
                                         learner_queue_size)
        self.learner.start()

        if len(self.remote_evaluators) == 0:
            logger.warning("Config num_workers=0 means training will hang!")

        # Stats
        self._optimizer_step_timer = TimerStat()
        self.num_weight_syncs = 0
        self.num_replayed = 0
        self._stats_start_time = time.time()
        self._last_stats_time = {}
        self._last_stats_val = {}

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        weights = self.local_evaluator.get_weights()
        for ev in self.remote_evaluators:
            ev.set_weights.remote(weights)
            for _ in range(max_sample_requests_in_flight_per_worker):
                self.sample_tasks.add(ev, ev.sample.remote())

        self.batch_buffer = []

        if replay_proportion:
            if replay_buffer_num_slots * sample_batch_size <= train_batch_size:
                raise ValueError(
                    "Replay buffer size is too small to produce train, "
                    "please increase replay_buffer_num_slots.",
                    replay_buffer_num_slots, sample_batch_size,
                    train_batch_size)
        self.replay_proportion = replay_proportion
        self.replay_buffer_num_slots = replay_buffer_num_slots
        self.replay_batches = []
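        # Note (assumption about the replay ratio): replay_proportion=p
        # replays stored batches and new batches in a p:1 ratio, e.g.
        # p=0.5 trains on one replayed batch for every two new ones.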

    def __init__(
        self,
        num_shards: int = 1,
        learning_starts: int = 1000,
        capacity: int = 10000,
        replay_batch_size: int = 1,
        prioritized_replay_alpha: float = 0.6,
        prioritized_replay_beta: float = 0.4,
        prioritized_replay_eps: float = 1e-6,
        replay_mode: str = "independent",
        replay_sequence_length: int = 1,
        replay_burn_in: int = 0,
        replay_zero_init_states: bool = True,
        buffer_size=DEPRECATED_VALUE,
    ):
        """Initializes a MultiAgentReplayBuffer instance.

        Args:
            num_shards: The number of buffer shards that exist in total
                (including this one).
            learning_starts: Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            capacity: The capacity of the buffer. Note that when
                `replay_sequence_length` > 1, this is the number of sequences
                (not single timesteps) stored.
            replay_batch_size: The batch size to be sampled (in timesteps).
                Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            prioritized_replay_alpha: Alpha parameter for a prioritized
                replay buffer. Use 0.0 for no prioritization.
            prioritized_replay_beta: Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps: Epsilon parameter for a prioritized
                replay buffer.
            replay_mode: One of "independent" or "lockstep". Determines
                whether, in the multi-agent case, experiences are sampled
                independently per policy or with the same indices across
                all agents/policies.
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in: The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 each RNN rollout.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
        """
        # Deprecated args.
        if buffer_size != DEPRECATED_VALUE:
            deprecation_warning("ReplayBuffer(size)",
                                "ReplayBuffer(capacity)",
                                error=False)
            capacity = buffer_size

        self.replay_starts = learning_starts // num_shards
        self.capacity = capacity // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.replay_mode = replay_mode
        self.replay_sequence_length = replay_sequence_length
        self.replay_burn_in = replay_burn_in
        self.replay_zero_init_states = replay_zero_init_states

        if replay_sequence_length > 1:
            self.replay_batch_size = int(
                max(1, replay_batch_size // replay_sequence_length))
            logger.info(
                "Since replay_sequence_length={} and replay_batch_size={}, "
                "we will replay {} sequences at a time.".format(
                    replay_sequence_length, replay_batch_size,
                    self.replay_batch_size))
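        # Example (illustrative): replay_batch_size=32 with
        # replay_sequence_length=4 gives self.replay_batch_size = 8, i.e.
        # 8 sequences (B) of 4 timesteps (T) are sampled per replay() call.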

        if replay_mode not in ["lockstep", "independent"]:
            raise ValueError("Unsupported replay mode: {}".format(replay_mode))

        def gen_replay():
            while True:
                yield self.replay()

        ParallelIteratorWorker.__init__(self, gen_replay, False)
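        # Note (illustrative): gen_replay exposes this buffer as an endless
        # iterator of replayed batches, which ParallelIteratorWorker serves
        # to remote consumers.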

        def new_buffer():
            if prioritized_replay_alpha == 0.0:
                return ReplayBuffer(self.capacity)
            else:
                return PrioritizedReplayBuffer(self.capacity,
                                               alpha=prioritized_replay_alpha)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics.
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # Make externally accessible for testing.
        global _local_replay_buffer
        _local_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None

class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """
    @override(PolicyOptimizer)
    def _init(self, num_sgd_iter=1, train_batch_size=1):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.train_batch_size = train_batch_size
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)
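            # Example (illustrative): with train_batch_size=500 and two
            # remote evaluators each returning 200-step batches, the loop
            # runs twice and concatenates 800 steps into one train batch.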

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.learn_on_batch(samples)
                if "stats" in fetches:
                    self.learner_stats = fetches["stats"]
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms":
                round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms":
                round(1000 * self.grad_timer.mean, 3),
                "update_time_ms":
                round(1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput":
                round(self.grad_timer.mean_throughput, 3),
                "sample_peak_throughput":
                round(self.sample_timer.mean_throughput, 3),
                "opt_samples":
                round(self.grad_timer.mean_units_processed, 3),
                "learner":
                self.learner_stats,
            })

    def __init__(
        self,
        capacity: int = 10000,
        storage_unit: str = "timesteps",
        num_shards: int = 1,
        replay_batch_size: int = 1,
        learning_starts: int = 1000,
        replay_mode: str = "independent",
        replay_sequence_length: int = 1,
        replay_burn_in: int = 0,
        replay_zero_init_states: bool = True,
        prioritized_replay_alpha: float = 0.6,
        prioritized_replay_beta: float = 0.4,
        prioritized_replay_eps: float = 1e-6,
        underlying_buffer_config: dict = None,
        **kwargs
    ):
        """Initializes a MultiAgentReplayBuffer instance.

        Args:
            num_shards: The number of buffer shards that exist in total
                (including this one).
            storage_unit: Either 'timesteps', 'sequences' or
                'episodes'. Specifies how experiences are stored. If they
                are stored in episodes, replay_sequence_length is ignored.
            learning_starts: Number of timesteps after which a call to
                `replay()` will yield samples (before that, `replay()` will
                return None).
            capacity: The capacity of the buffer. Note that when
                `replay_sequence_length` > 1, this is the number of sequences
                (not single timesteps) stored.
            replay_batch_size: The batch size to be sampled (in timesteps).
                Note that if `replay_sequence_length` > 1,
                `self.replay_batch_size` will be set to the number of
                sequences sampled (B).
            prioritized_replay_alpha: Alpha parameter for a prioritized
                replay buffer. Use 0.0 for no prioritization.
            prioritized_replay_beta: Beta parameter for a prioritized
                replay buffer.
            prioritized_replay_eps: Epsilon parameter for a prioritized
                replay buffer.
            replay_sequence_length: The sequence length (T) of a single
                sample. If > 1, we will sample B x T from this buffer.
            replay_burn_in: The burn-in length in case
                `replay_sequence_length` > 0. This is the number of timesteps
                each sequence overlaps with the previous one to generate a
                better internal state (=state after the burn-in), instead of
                starting from 0.0 each RNN rollout.
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
            underlying_buffer_config: A config that contains all necessary
                constructor arguments and arguments for methods to call on
                the underlying buffers. This replaces the standard behaviour
                of the underlying PrioritizedReplayBuffer. The config
                follows the conventions of the general
                replay_buffer_config. kwargs for subsequent calls of methods
                may also be included. Example:
                "replay_buffer_config": {"type": PrioritizedReplayBuffer,
                "capacity": 10, "storage_unit": "timesteps",
                "prioritized_replay_alpha": 0.5,
                "prioritized_replay_beta": 0.5,
                "prioritized_replay_eps": 0.5}
            **kwargs: Forward compatibility kwargs.
        """
        if "replay_mode" in kwargs and (
            kwargs["replay_mode"] == "lockstep"
            or kwargs["replay_mode"] == ReplayMode.LOCKSTEP
        ):
            if log_once("lockstep_mode_not_supported"):
                logger.error(
                    "Replay mode `lockstep` is not supported for "
                    "MultiAgentPrioritizedReplayBuffer. "
                    "This buffer will run in `independent` mode."
                )
            kwargs["replay_mode"] = "independent"

        if underlying_buffer_config is not None:
            if log_once("underlying_buffer_config_not_supported"):
                logger.info(
                    "PrioritizedMultiAgentReplayBuffer instantiated "
                    "with underlying_buffer_config. This will "
                    "overwrite the standard behaviour of the "
                    "underlying PrioritizedReplayBuffer."
                )
            prioritized_replay_buffer_config = underlying_buffer_config
        else:
            prioritized_replay_buffer_config = {
                "type": PrioritizedReplayBuffer,
                "alpha": prioritized_replay_alpha,
                "beta": prioritized_replay_beta,
            }

        shard_capacity = capacity // num_shards
        MultiAgentReplayBuffer.__init__(
            self,
            shard_capacity,
            storage_unit,
            **kwargs,
            underlying_buffer_config=prioritized_replay_buffer_config,
            replay_batch_size=replay_batch_size,
            learning_starts=learning_starts,
            replay_mode=replay_mode,
            replay_sequence_length=replay_sequence_length,
            replay_burn_in=replay_burn_in,
            replay_zero_init_states=replay_zero_init_states,
        )

        self.prioritized_replay_eps = prioritized_replay_eps
        self.update_priorities_timer = TimerStat()
Exemple #32
0
 def _init(self, batch_size=32):
     self.update_weights_timer = TimerStat()
     self.sample_timer = TimerStat()
     self.grad_timer = TimerStat()
     self.throughput = RunningStat()
     self.batch_size = batch_size
Exemple #33
0
class SyncBatchReplayOptimizer(PolicyOptimizer):
    """Variant of the sync replay optimizer that replays entire batches.

    This enables RNN support. Does not currently support prioritization."""
    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              train_batch_size=32):
        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray.get(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                            batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.compute_apply(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
Exemple #34
0
    def __init__(self,
                 workers,
                 sgd_batch_size=128,
                 num_sgd_iter=10,
                 rollout_fragment_length=200,
                 num_envs_per_worker=1,
                 train_batch_size=1024,
                 num_gpus=0,
                 standardize_fields=[],
                 shuffle_sequences=True,
                 _fake_gpus=False):
        """Initialize a synchronous multi-gpu optimizer.

        Arguments:
            workers (WorkerSet): all workers
            sgd_batch_size (int): SGD minibatch size within train batch size
            num_sgd_iter (int): number of passes to learn on per train batch
            rollout_fragment_length (int): size of batches to sample from
                workers.
            num_envs_per_worker (int): num envs in each rollout worker
            train_batch_size (int): size of batches to learn on
            num_gpus (int): number of GPUs to use for data-parallel SGD
            standardize_fields (list): list of fields in the training batch
                to normalize
            shuffle_sequences (bool): whether to shuffle the train batch prior
                to SGD to break up correlations
            _fake_gpus (bool): Whether to use fake-GPUs (CPUs) instead of
                actual GPUs (should only be used for testing on non-GPU
                machines).
        """
        PolicyOptimizer.__init__(self, workers)

        self.batch_size = sgd_batch_size
        self.num_sgd_iter = num_sgd_iter
        self.num_envs_per_worker = num_envs_per_worker
        self.rollout_fragment_length = rollout_fragment_length
        self.train_batch_size = train_batch_size
        self.shuffle_sequences = shuffle_sequences

        # Collect actual devices to use.
        if not num_gpus:
            _fake_gpus = True
            num_gpus = 1
        type_ = "cpu" if _fake_gpus else "gpu"
        self.devices = [
            "/{}:{}".format(type_, i) for i in range(int(math.ceil(num_gpus)))
        ]

        self.batch_size = int(sgd_batch_size / len(self.devices)) * len(
            self.devices)
        assert self.batch_size % len(self.devices) == 0
        assert self.batch_size >= len(self.devices), "batch size too small"
        self.per_device_batch_size = int(self.batch_size / len(self.devices))
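        # Example (illustrative): sgd_batch_size=128 on 3 devices rounds
        # batch_size down to int(128 / 3) * 3 = 126, giving
        # per_device_batch_size = 42 on every device.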
        self.sample_timer = TimerStat()
        self.load_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.update_weights_timer = TimerStat()
        self.standardize_fields = standardize_fields

        logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))

        self.policies = dict(self.workers.local_worker()
                             .foreach_trainable_policy(lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))
        for policy_id, policy in self.policies.items():
            if not isinstance(policy, TFPolicy):
                raise ValueError(
                    "Only TF graph policies are supported with multi-GPU. "
                    "Try setting `simple_optimizer=True` instead.")

        # per-GPU graph copies created below must share vars with the policy
        # reuse is set to AUTO_REUSE because Adam nodes are created after
        # all of the device copies are created.
        self.optimizers = {}
        with self.workers.local_worker().tf_sess.graph.as_default():
            with self.workers.local_worker().tf_sess.as_default():
                for policy_id, policy in self.policies.items():
                    with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
                        if policy._state_inputs:
                            rnn_inputs = policy._state_inputs + [
                                policy._seq_lens
                            ]
                        else:
                            rnn_inputs = []
                        self.optimizers[policy_id] = (
                            LocalSyncParallelOptimizer(
                                policy._optimizer, self.devices,
                                [v
                                 for _, v in policy._loss_inputs], rnn_inputs,
                                self.per_device_batch_size, policy.copy))

                self.sess = self.workers.local_worker().tf_sess
                self.sess.run(tf.global_variables_initializer())
Exemple #35
0
 def _init(self, num_sgd_iter=1):
     self.update_weights_timer = TimerStat()
     self.sample_timer = TimerStat()
     self.grad_timer = TimerStat()
     self.throughput = RunningStat()
     self.num_sgd_iter = num_sgd_iter

    def __init__(self,
                 workers,
                 sgd_batch_size=128,
                 num_sgd_iter=10,
                 sample_batch_size=200,
                 num_envs_per_worker=1,
                 train_batch_size=1024,
                 num_gpus=0,
                 standardize_fields=[],
                 straggler_mitigation=False):
        PolicyOptimizer.__init__(self, workers)

        self.batch_size = sgd_batch_size
        self.num_sgd_iter = num_sgd_iter
        self.num_envs_per_worker = num_envs_per_worker
        self.sample_batch_size = sample_batch_size
        self.train_batch_size = train_batch_size
        self.straggler_mitigation = straggler_mitigation
        if not num_gpus:
            self.devices = ["/cpu:0"]
        else:
            self.devices = [
                "/gpu:{}".format(i) for i in range(int(math.ceil(num_gpus)))
            ]
        self.batch_size = int(sgd_batch_size / len(self.devices)) * len(
            self.devices)
        assert self.batch_size % len(self.devices) == 0
        assert self.batch_size >= len(self.devices), "batch size too small"
        self.per_device_batch_size = int(self.batch_size / len(self.devices))
        self.sample_timer = TimerStat()
        self.load_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.update_weights_timer = TimerStat()
        self.standardize_fields = standardize_fields

        logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))

        self.policies = dict(
            self.workers.local_worker().foreach_trainable_policy(lambda p, i:
                                                                 (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))
        for policy_id, policy in self.policies.items():
            if not isinstance(policy, TFPolicy):
                raise ValueError(
                    "Only TF policies are supported with multi-GPU. Try using "
                    "the simple optimizer instead.")

        # per-GPU graph copies created below must share vars with the policy
        # reuse is set to AUTO_REUSE because Adam nodes are created after
        # all of the device copies are created.
        self.optimizers = {}
        with self.workers.local_worker().tf_sess.graph.as_default():
            with self.workers.local_worker().tf_sess.as_default():
                for policy_id, policy in self.policies.items():
                    with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
                        if policy._state_inputs:
                            rnn_inputs = policy._state_inputs + [
                                policy._seq_lens
                            ]
                        else:
                            rnn_inputs = []
                        self.optimizers[policy_id] = (
                            LocalSyncParallelOptimizer(
                                policy._optimizer, self.devices,
                                [v
                                 for _, v in policy._loss_inputs], rnn_inputs,
                                self.per_device_batch_size, policy.copy))

                self.sess = self.workers.local_worker().tf_sess
                self.sess.run(tf.global_variables_initializer())

class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """

    @override(PolicyOptimizer)
    def _init(self, num_sgd_iter=1, train_batch_size=1):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.train_batch_size = train_batch_size
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.compute_apply(samples)
                if "stats" in fetches:
                    self.learner_stats = fetches["stats"]
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })
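
A rough usage sketch for the optimizer above (hypothetical; evaluator
construction is elided, and the constructor wiring assumes the base class
accepts the local and remote evaluators referenced in step()):

# Illustrative driver loop only; not taken from the library.
optimizer = SyncSamplesOptimizer(
    local_evaluator, remote_evaluators,
    num_sgd_iter=10, train_batch_size=4000)

for _ in range(100):
    optimizer.step()          # broadcast weights, sample, run SGD
    print(optimizer.stats())  # per-phase timings and throughputs
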
Example #38
    def _init(self):
        self.apply_timer = TimerStat()
        self.wait_timer = TimerStat()
        self.dispatch_timer = TimerStat()
        self.grads_per_step = self.config.get("grads_per_step", 100)

    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=512,
                 rollout_fragment_length=50,
                 num_replay_buffer_shards=1,
                 max_weight_sync_delay=400,
                 debug=False,
                 batch_replay=False):
        """Initialize an async replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            train_batch_size (int): size of batches to learn on
            rollout_fragment_length (int): size of batches to sample from
                workers.
            num_replay_buffer_shards (int): number of actors to use to store
                replay samples
            max_weight_sync_delay (int): update the weights of a rollout worker
                after collecting this number of timesteps from it
            debug (bool): return extra debug stats
            batch_replay (bool): replay entire sequential batches of
                experiences instead of sampling steps individually
        """
        PolicyOptimizer.__init__(self, workers)

        self.debug = debug
        self.batch_replay = batch_replay
        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.max_weight_sync_delay = max_weight_sync_delay

        self.learner = LearnerThread(self.workers.local_worker())
        self.learner.start()

        if self.batch_replay:
            replay_cls = BatchReplayActor
        else:
            replay_cls = ReplayActor
        self.replay_actors = create_colocated(replay_cls, [
            num_replay_buffer_shards,
            learning_starts,
            buffer_size,
            train_batch_size,
            prioritized_replay_alpha,
            prioritized_replay_beta,
            prioritized_replay_eps,
        ], num_replay_buffer_shards)

        # Stats
        self.timers = {
            k: TimerStat()
            for k in [
                "put_weights", "get_samples", "sample_processing",
                "replay_processing", "update_priorities", "train", "sample"
            ]
        }
        self.num_weight_syncs = 0
        self.num_samples_dropped = 0
        self.learning_started = False

        # Number of worker steps since the last weight update
        self.steps_since_update = {}

        # Kick off replay tasks that serve batches for local gradient updates
        self.replay_tasks = TaskPool()
        for ra in self.replay_actors:
            for _ in range(REPLAY_QUEUE_DEPTH):
                self.replay_tasks.add(ra, ra.replay.remote())

        # Kick off async background sampling
        self.sample_tasks = TaskPool()
        if self.workers.remote_workers():
            self._set_workers(self.workers.remote_workers())
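
A simplified sketch of how the pieces created above interact: finished
replay.remote() calls are re-armed on the task pool and their batches are
handed to the background learner thread. This is illustrative only; the
real stepping logic also handles priority updates, weight syncing, and a
full learner inqueue.

# Assumed polling flow, based solely on the structures built in __init__.
def _poll_replay_sketch(self):
    for ra, replay in self.replay_tasks.completed():
        # Re-arm the shard so it keeps producing training batches.
        self.replay_tasks.add(ra, ra.replay.remote())
        with self.timers["replay_processing"]:
            # Fetch the batch and queue it for the learner thread.
            self.learner.inqueue.put((ra, ray.get(replay)))
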
Example #40
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight session runs for gradient ops off the
    main thread improves overall throughput.
    """
    def __init__(self, local_worker):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        self.learner_info = {}

    def run(self):
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    # Use LearnerInfoBuilder as a unified way to build the
                    # final results dict from `learn_on_loaded_batch` call(s).
                    # This makes sure results dicts always have the same
                    # structure no matter the setup (multi-GPU, multi-agent,
                    # minibatch SGD, tf vs torch).
                    learner_info_builder = LearnerInfoBuilder(num_devices=1)
                    multi_agent_results = self.local_worker.learn_on_batch(
                        replay)
                    for pid, results in multi_agent_results.items():
                        learner_info_builder.add_learn_on_batch_results(
                            results, pid)
                        td_error = results["td_error"]
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the batch indices; if the indices were
                        # converted, they could be manipulated as part of a
                        # tensor and cause errors when sent back to the
                        # buffer for priority updates.
                        replay.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            replay.policy_batches[pid].get("batch_indexes"),
                            td_error)
                    self.learner_info = learner_info_builder.finalize()
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
            self.overall_timer.push_units_processed(
                replay.count if replay else 0)
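
On the main-thread side, the (actor, priority-dict, count) tuples that the
learner places on outqueue would be drained roughly as follows (an assumed
consumer for illustration; the real one also maintains timers and counters):

# Hypothetical consumer of LearnerThread.outqueue; illustrative only.
def drain_learner_outqueue(optimizer):
    while not optimizer.learner.outqueue.empty():
        ra, prio_dict, count = optimizer.learner.outqueue.get()
        # Send fresh TD-errors back to the replay shard that produced
        # the batch so its sampling priorities stay current.
        ra.update_priorities.remote(prio_dict)
        optimizer.num_steps_trained += count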