Example #1
class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """

    def _init(self, num_sgd_iter=1, timesteps_per_batch=1):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.timesteps_per_batch = timesteps_per_batch

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.timesteps_per_batch:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.compute_apply(samples)
                if self.num_sgd_iter > 1:
                    print(i, fetches)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches

    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
            })
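A minimal driver sketch for this optimizer. How the optimizer is constructed depends on the PolicyOptimizer base class in this RLlib version, so `make_optimizer()` below is a hypothetical stand-in; only `step()` and `stats()` come from the example above:

# Hypothetical driver loop; make_optimizer() stands in for however the
# PolicyOptimizer base class wires up local/remote evaluators.
optimizer = make_optimizer()
for _ in range(10):
    fetches = optimizer.step()   # broadcast weights, sample, run SGD
    print(optimizer.stats())     # e.g. sample_time_ms, opt_peak_throughput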
Example #2
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight gradient ops (session runs) off the main
    thread improves overall throughput.
    """

    def __init__(self, local_worker):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        self.stats = {}

    def run(self):
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    grad_out = self.local_worker.learn_on_batch(replay)
                    for pid, info in grad_out.items():
                        td_error = info.get(
                            "td_error",
                            info[LEARNER_STATS_KEY].get("td_error"))
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the indices. This may lead to errors
                        # when sent to the buffer for processing
                        # (may get manipulated if they are part of a tensor).
                        replay.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            replay.policy_batches[pid].get("batch_indexes"),
                            td_error)
                        self.stats[pid] = get_learner_stats(info)
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
            self.overall_timer.push_units_processed(replay and replay.count
                                                    or 0)
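For orientation, a rough sketch of the handshake between this thread and its driver: the main thread pushes (replay_actor, batch) pairs into `inqueue` and drains (replay_actor, priorities, count) triples from `outqueue`. The names `local_worker`, `replay_actor`, and `replay_batch` are hypothetical stand-ins:

learner = LearnerThread(local_worker)  # local_worker: a RolloutWorker-like object
learner.start()                        # step() now loops in the background

learner.inqueue.put((replay_actor, replay_batch))  # producer side
ra, prio_dict, count = learner.outqueue.get()      # consumer side
learner.stopped = True                             # ask the loop to exit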
Example #3
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight gradient ops (session runs) off the main
    thread improves overall throughput.
    """

    def __init__(self, local_worker):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        self.stats = {}

    def run(self):
        while not self.stopped:
            self.step()

    def step(self):
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    grad_out = self.local_worker.learn_on_batch(replay)
                    for pid, info in grad_out.items():
                        td_error = info.get(
                            "td_error",
                            info[LEARNER_STATS_KEY].get("td_error"))
                        prio_dict[pid] = (replay.policy_batches[pid].data.get(
                            "batch_indexes"), td_error)
                        self.stats[pid] = get_learner_stats(info)
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
            self.overall_timer.push_units_processed(replay and replay.count
                                                    or 0)
Example #4
class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """
    def _init(self, batch_size=32):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.batch_size = batch_size

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                samples = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                samples = self.local_evaluator.sample()

        with self.grad_timer:
            grad, _ = self.local_evaluator.compute_gradients(samples)
            self.local_evaluator.apply_gradients(grad)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count

    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
            })
Example #5
class LocalSyncOptimizer(Optimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """

    def _init(self, batch_size=32):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.batch_size = batch_size

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                samples = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                samples = self.local_evaluator.sample()

        with self.grad_timer:
            grad = self.local_evaluator.compute_gradients(samples)
            self.local_evaluator.apply_gradients(grad)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count

    def stats(self):
        return dict(Optimizer.stats(self), **{
            "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
            "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
            "update_time_ms": round(1000 * self.update_weights_timer.mean, 3),
            "opt_peak_throughput": round(self.grad_timer.mean_throughput, 3),
            "opt_samples": round(self.grad_timer.mean_units_processed, 3),
        })
Example #6
class SyncBatchReplayOptimizer(PolicyOptimizer):
    """Variant of the sync replay optimizer that replays entire batches.

    This enables RNN support. Does not currently support prioritization."""
    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 train_batch_size=32):
        """Initialize a batch replay optimizer.

        Arguments:
            workers (WorkerSet): set of all workers
            learning_starts (int): start learning after this number of
                timesteps have been collected
            buffer_size (int): max timesteps to keep in the replay buffer
            train_batch_size (int): number of timesteps to train on at once
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batches = ray_get_and_free(
                    [e.sample.remote() for e in self.workers.remote_workers()])
            else:
                batches = [self.workers.local_worker().sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                            batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                if batch.count > self.max_buffer_size:
                    raise ValueError(
                        "The size of a single sample batch exceeds the replay "
                        "buffer size ({} > {})".format(batch.count,
                                                       self.max_buffer_size))
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            return self._optimize()
        else:
            return {}

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
        return info_dict
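The buffer above stores whole SampleBatch objects and evicts oldest-first once `buffer_size` exceeds `max_buffer_size`, while `_optimize` concatenates randomly chosen batches until at least `train_batch_size` timesteps are gathered. A toy model of the eviction accounting, with plain integers standing in for batch counts:

# Toy model of the FIFO eviction loop; ints stand in for SampleBatch.count.
replay_buffer, buffer_size, max_buffer_size = [], 0, 10
for count in [4, 4, 4]:                # three incoming batches of 4 timesteps
    replay_buffer.append(count)
    buffer_size += count
    while buffer_size > max_buffer_size:
        buffer_size -= replay_buffer.pop(0)   # evict the oldest batch
print(replay_buffer, buffer_size)      # [4, 4] 8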
Example #7
class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight gradient ops (session runs) off the main
    thread improves overall throughput.
    """
    def __init__(self, local_worker):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.overall_timer = TimerStat()
        self.daemon = True
        self.weights_updated = False
        self.stopped = False
        self.learner_info = {}

    def run(self):
        # Switch on eager mode if configured.
        if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        with self.overall_timer:
            with self.queue_timer:
                ra, replay = self.inqueue.get()
            if replay is not None:
                prio_dict = {}
                with self.grad_timer:
                    # Use LearnerInfoBuilder as a unified way to build the
                    # final results dict from `learn_on_loaded_batch` call(s).
                    # This makes sure results dicts always have the same
                    # structure no matter the setup (multi-GPU, multi-agent,
                    # minibatch SGD, tf vs torch).
                    learner_info_builder = LearnerInfoBuilder(num_devices=1)
                    multi_agent_results = self.local_worker.learn_on_batch(
                        replay)
                    for pid, results in multi_agent_results.items():
                        learner_info_builder.add_learn_on_batch_results(
                            results, pid)
                        td_error = results["td_error"]
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the indices. This may lead to errors
                        # when sent to the buffer for processing
                        # (may get manipulated if they are part of a tensor).
                        replay.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            replay.policy_batches[pid].get("batch_indexes"),
                            td_error)
                    self.learner_info = learner_info_builder.finalize()
                    self.grad_timer.push_units_processed(replay.count)
                self.outqueue.put((ra, prio_dict, replay.count))
            self.learner_queue_size.push(self.inqueue.qsize())
            self.weights_updated = True
            self.overall_timer.push_units_processed(replay and replay.count
                                                    or 0)
Example #8
class MicrobatchOptimizer(PolicyOptimizer):
    """A microbatching synchronous RL optimizer.

    This optimizer pulls sample batches from workers until the target
    microbatch size is reached. Then, it computes and accumulates the policy
    gradient in a local buffer. This process is repeated until the number of
    samples collected equals the train batch size. Then, an accumulated
    gradient update is made.

    This allows for training with effective batch sizes much larger than can
    fit in GPU or host memory.
    """
    def __init__(self, workers, train_batch_size=10000, microbatch_size=1000):
        PolicyOptimizer.__init__(self, workers)

        if train_batch_size <= microbatch_size:
            raise ValueError(
                "The microbatch size must be smaller than the train batch "
                "size, got {} vs {}".format(microbatch_size, train_batch_size))

        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.train_batch_size = train_batch_size
        self.microbatch_size = microbatch_size
        self.learner_stats = {}
        self.policies = dict(
            self.workers.local_worker().foreach_trainable_policy(lambda p, i:
                                                                 (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        fetches = {}
        accumulated_gradients = {}
        samples_so_far = 0

        # Accumulate minibatches.
        i = 0
        while samples_so_far < self.train_batch_size:
            i += 1
            with self.sample_timer:
                samples = []
                while sum(s.count for s in samples) < self.microbatch_size:
                    if self.workers.remote_workers():
                        samples.extend(
                            ray.get([
                                e.sample.remote()
                                for e in self.workers.remote_workers()
                            ]))
                    else:
                        samples.append(self.workers.local_worker().sample())
                samples = SampleBatch.concat_samples(samples)
                self.sample_timer.push_units_processed(samples.count)
                samples_so_far += samples.count

            logger.info(
                "Computing gradients for microbatch {} ({}/{} samples)".format(
                    i, samples_so_far, self.train_batch_size))

            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples},
                                          samples.count)

            with self.grad_timer:
                for policy_id, policy in self.policies.items():
                    if policy_id not in samples.policy_batches:
                        continue
                    batch = samples.policy_batches[policy_id]
                    grad_out, info_out = (
                        self.workers.local_worker().compute_gradients(
                            MultiAgentBatch({policy_id: batch}, batch.count)))
                    grad = grad_out[policy_id]
                    fetches.update(info_out)
                    if policy_id not in accumulated_gradients:
                        accumulated_gradients[policy_id] = grad
                    else:
                        grad_size = len(accumulated_gradients[policy_id])
                        assert grad_size == len(grad), (grad_size, len(grad))
                        c = []
                        for a, b in zip(accumulated_gradients[policy_id],
                                        grad):
                            c.append(a + b)
                        accumulated_gradients[policy_id] = c
            self.grad_timer.push_units_processed(samples.count)

        # Apply the accumulated gradient
        logger.info("Applying accumulated gradients ({} samples)".format(
            samples_so_far))
        self.workers.local_worker().apply_gradients(accumulated_gradients)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples_so_far
        self.num_steps_trained += samples_so_far
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms":
                round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms":
                round(1000 * self.grad_timer.mean, 3),
                "update_time_ms":
                round(1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput":
                round(self.grad_timer.mean_throughput, 3),
                "sample_peak_throughput":
                round(self.sample_timer.mean_throughput, 3),
                "opt_samples":
                round(self.grad_timer.mean_units_processed, 3),
                "learner":
                self.learner_stats,
            })
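The accumulation above relies on gradients being additive: for a loss summed over samples, summing per-microbatch gradients variable-by-variable reproduces the full-batch gradient. A small self-contained numpy sketch of the element-wise accumulation step:

import numpy as np

# Per-variable gradient lists from two microbatches of one policy.
grad_a = [np.array([1.0, 2.0]), np.array([[0.5]])]
grad_b = [np.array([3.0, 4.0]), np.array([[1.5]])]

assert len(grad_a) == len(grad_b)
accumulated = [a + b for a, b in zip(grad_a, grad_b)]
# accumulated == [array([4., 6.]), array([[2.]])]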
Example #9
class SacSyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(self,
                 workers,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 train_batch_size=32,
                 sample_batch_size=4,
                 before_learn_on_batch=None,
                 synchronize_sampling=False,
                 workers_only_sync_policy_list=None):
        """Initialize an sync replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            schedule_max_timesteps (int): number of timesteps in the schedule
            beta_annealing_fraction (float): fraction of schedule to anneal
                beta over
            final_prioritized_replay_beta (float): final value of beta
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
            workers_only_sync_policy_list (list): if given, only the weights
                of these policies are broadcast to the remote workers.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        # linearly annealing beta used in Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(schedule_max_timesteps *
                                   beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        self.workers_only_sync_policy_list = workers_only_sync_policy_list
        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                # !!!!! CHANGED FROM ORIGINAL !!!!: don't sync policies we
                # aren't training.
                weights = ray.put(self.workers.local_worker().get_weights(
                    policies=self.workers_only_sync_policy_list))
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
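The beta annealing above follows the Rainbow recipe: anneal linearly over the first `schedule_max_timesteps * beta_annealing_fraction` trained steps, then hold the final value. A minimal stand-in for RLlib's LinearSchedule (the real class lives in RLlib's utils; this sketch only mirrors its documented behavior):

# Minimal LinearSchedule stand-in, for illustration only.
class LinearScheduleSketch:
    def __init__(self, schedule_timesteps, initial_p, final_p):
        self.schedule_timesteps = schedule_timesteps
        self.initial_p = initial_p
        self.final_p = final_p

    def value(self, t):
        # Fraction of the schedule elapsed, clamped to [0, 1].
        frac = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + frac * (self.final_p - self.initial_p)

beta = LinearScheduleSketch(20000, initial_p=0.4, final_p=1.0)
print(beta.value(0), beta.value(10000), beta.value(50000))  # 0.4 0.7 1.0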
Example #10
class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    workers, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote workers.
    """

    def __init__(self,
                 workers,
                 num_sgd_iter=1,
                 train_batch_size=1,
                 sgd_minibatch_size=0):
        PolicyOptimizer.__init__(self, workers)

        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.train_batch_size = train_batch_size
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                for minibatch in self._minibatches(samples):
                    fetches = self.workers.local_worker().learn_on_batch(
                        minibatch)
                self.learner_stats = get_learner_stats(fetches)
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _minibatches(self, samples):
        if not self.sgd_minibatch_size:
            yield samples
            return

        if isinstance(samples, MultiAgentBatch):
            raise NotImplementedError(
                "Minibatching not implemented for multi-agent in simple mode")

        if "state_in_0" in samples.data:
            logger.warning("Not shuffling RNN data for SGD in simple mode")
        else:
            samples.shuffle()

        i = 0
        slices = []
        while i < samples.count:
            slices.append((i, i + self.sgd_minibatch_size))
            i += self.sgd_minibatch_size
        random.shuffle(slices)

        for i, j in slices:
            yield samples.slice(i, j)
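A quick trace of the index arithmetic in `_minibatches`: slices of `sgd_minibatch_size` are laid out over the batch, shuffled, then yielded (shuffling omitted in this toy; the final slice may extend past `samples.count`, which array-backed slicing simply truncates):

# Toy trace of the slicing logic using plain indices.
count, sgd_minibatch_size = 10, 4
i, slices = 0, []
while i < count:
    slices.append((i, i + sgd_minibatch_size))
    i += sgd_minibatch_size
print(slices)  # [(0, 4), (4, 8), (8, 12)] -- the last slice truncates to (8, 10)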
Example #11
class LocalSyncReplayOptimizer(Optimizer):
    """Variant of the local sync optimizer that supports replay (for DQN)."""

    def _init(
            self, learning_starts=1000, buffer_size=10000,
            prioritized_replay=True, prioritized_replay_alpha=0.6,
            prioritized_replay_beta=0.4, prioritized_replay_eps=1e-6,
            train_batch_size=32, sample_batch_size=4, clip_rewards=True):

        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()

        # Set up replay buffer
        if prioritized_replay:
            self.replay_buffer = PrioritizedReplayBuffer(
                buffer_size, alpha=prioritized_replay_alpha,
                clip_rewards=clip_rewards)
        else:
            self.replay_buffer = ReplayBuffer(buffer_size, clip_rewards)

        assert buffer_size >= self.replay_starts

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()
            for row in batch.rows():
                self.replay_buffer.add(
                    row["obs"], row["actions"], row["rewards"], row["new_obs"],
                    row["dones"], row["weights"])

        if len(self.replay_buffer) >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    def _optimize(self):
        with self.replay_timer:
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1,
                    dones, weights, batch_indexes) = self.replay_buffer.sample(
                        self.train_batch_size,
                        beta=self.prioritized_replay_beta)
            else:
                (obses_t, actions, rewards, obses_tp1,
                    dones) = self.replay_buffer.sample(
                        self.train_batch_size)
                weights = np.ones_like(rewards)
                batch_indexes = - np.ones_like(rewards)

            samples = SampleBatch({
                "obs": obses_t, "actions": actions, "rewards": rewards,
                "new_obs": obses_tp1, "dones": dones, "weights": weights,
                "batch_indexes": batch_indexes})

        with self.grad_timer:
            td_error = self.local_evaluator.compute_apply(samples)
            new_priorities = (
                np.abs(td_error) + self.prioritized_replay_eps)
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                self.replay_buffer.update_priorities(
                    samples["batch_indexes"], new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def stats(self):
        return dict(Optimizer.stats(self), **{
            "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
            "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
            "update_time_ms": round(1000 * self.update_weights_timer.mean, 3),
            "opt_peak_throughput": round(self.grad_timer.mean_throughput, 3),
            "opt_samples": round(self.grad_timer.mean_units_processed, 3),
        })
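The priority refresh in `_optimize` above is the standard prioritized-replay rule: each replayed transition's priority becomes |td_error| + prioritized_replay_eps, where eps keeps priorities strictly positive so every transition retains a nonzero resample probability. In numpy terms:

import numpy as np

td_error = np.array([0.5, -1.2, 0.0])
prioritized_replay_eps = 1e-6
new_priorities = np.abs(td_error) + prioritized_replay_eps
# -> approx [0.500001, 1.200001, 1e-06]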
Example #12
class SyncBatchReplayOptimizer(PolicyOptimizer):
    """Variant of the sync replay optimizer that replays entire batches.

    This enables RNN support. Does not currently support prioritization."""
    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              train_batch_size=32):
        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray.get(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                            batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.compute_apply(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
Example #13
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(self,
                 local_evaluator,
                 remote_evaluators,
                 learning_starts=1000,
                 buffer_size=10000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 train_batch_size=32,
                 sample_batch_size=4):
        PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)

        self.replay_starts = learning_starts
        # linearly annealing beta used in Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(schedule_max_timesteps *
                                   beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray_get_and_free(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample(
                         self.train_batch_size,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample(self.train_batch_size)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
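When a policy's buffer is non-prioritized, `_replay` substitutes uniform importance weights and sentinel batch indexes so downstream code can treat both buffer types alike; in numpy terms:

import numpy as np

rewards = np.array([0.1, -0.2, 0.3])
weights = np.ones_like(rewards)          # uniform importance weights
batch_indexes = -np.ones_like(rewards)   # -1 sentinel: skip priority updates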
Example #14
class LocalSyncReplayOptimizer(Optimizer):
    """Variant of the local sync optimizer that supports replay (for DQN)."""
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=32,
              sample_batch_size=4):

        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()

        # Set up replay buffer
        if prioritized_replay:
            self.replay_buffer = PrioritizedReplayBuffer(
                buffer_size, alpha=prioritized_replay_alpha)
        else:
            self.replay_buffer = ReplayBuffer(buffer_size)

        assert buffer_size >= self.replay_starts

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()
            for row in batch.rows():
                self.replay_buffer.add(row["obs"], row["actions"],
                                       row["rewards"], row["new_obs"],
                                       row["dones"], row["weights"])

        if len(self.replay_buffer) >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    def _optimize(self):
        with self.replay_timer:
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = self.replay_buffer.sample(
                     self.train_batch_size, beta=self.prioritized_replay_beta)
            else:
                (obses_t, actions, rewards, obses_tp1,
                 dones) = self.replay_buffer.sample(self.train_batch_size)
                weights = np.ones_like(rewards)
                batch_indexes = -np.ones_like(rewards)

            samples = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })

        with self.grad_timer:
            td_error = self.local_evaluator.compute_apply(samples)
            new_priorities = (np.abs(td_error) + self.prioritized_replay_eps)
            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
                self.replay_buffer.update_priorities(samples["batch_indexes"],
                                                     new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def stats(self):
        return dict(
            Optimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
            })
Example #15
class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """
    def __init__(self,
                 local_evaluator,
                 remote_evaluators,
                 num_sgd_iter=1,
                 train_batch_size=1):
        PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators)

        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.train_batch_size = train_batch_size
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.learn_on_batch(samples)
                self.learner_stats = get_learner_stats(fetches)
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms":
                round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms":
                round(1000 * self.grad_timer.mean, 3),
                "update_time_ms":
                round(1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput":
                round(self.grad_timer.mean_throughput, 3),
                "sample_peak_throughput":
                round(self.sample_timer.mean_throughput, 3),
                "opt_samples":
                round(self.grad_timer.mean_units_processed, 3),
                "learner":
                self.learner_stats,
            })
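A minimal driver sketch for the optimizer above (not from the original source; `local_ev` and `remote_evs` are hypothetical, pre-built policy evaluators): construct the optimizer, call step() in a loop, and read stats() for timings.

# Hypothetical driver loop for SyncSamplesOptimizer; local_ev / remote_evs
# are assumed, already-constructed policy evaluators.
optimizer = SyncSamplesOptimizer(
    local_ev, remote_evs, num_sgd_iter=10, train_batch_size=4000)
for _ in range(100):
    learner_stats = optimizer.step()  # broadcast weights, sample, learn
print(optimizer.stats()["sample_time_ms"])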
Example #16
class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    workers, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote workers.
    """
    def __init__(self,
                 workers,
                 num_sgd_iter=1,
                 train_batch_size=1,
                 sgd_minibatch_size=0,
                 standardize_fields=frozenset([])):
        PolicyOptimizer.__init__(self, workers)

        self.update_weights_timer = TimerStat()
        self.standardize_fields = standardize_fields
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.train_batch_size = train_batch_size
        self.learner_stats = {}
        self.policies = dict(
            self.workers.local_worker().foreach_trainable_policy(lambda p, i:
                                                                 (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray.get([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            fetches = do_minibatch_sgd(samples, self.policies,
                                       self.workers.local_worker(),
                                       self.num_sgd_iter,
                                       self.sgd_minibatch_size,
                                       self.standardize_fields)
        self.grad_timer.push_units_processed(samples.count)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms":
                round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms":
                round(1000 * self.grad_timer.mean, 3),
                "update_time_ms":
                round(1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput":
                round(self.grad_timer.mean_throughput, 3),
                "sample_peak_throughput":
                round(self.sample_timer.mean_throughput, 3),
                "opt_samples":
                round(self.grad_timer.mean_units_processed, 3),
                "learner":
                self.learner_stats,
            })
Example #17
class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    workers, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote workers.
    """

    def __init__(self,
                 workers,
                 num_sgd_iter=1,
                 train_batch_size=1,
                 sgd_minibatch_size=0,
                 standardize_fields=frozenset([])):
        PolicyOptimizer.__init__(self, workers)

        self.update_weights_timer = TimerStat()
        self.standardize_fields = standardize_fields
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.train_batch_size = train_batch_size
        self.learner_stats = {}
        self.policies = dict(self.workers.local_worker()
                             .foreach_trainable_policy(lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Handle everything as if multiagent
        if isinstance(samples, SampleBatch):
            samples = MultiAgentBatch({
                DEFAULT_POLICY_ID: samples
            }, samples.count)

        fetches = {}
        with self.grad_timer:
            for policy_id, policy in self.policies.items():
                if policy_id not in samples.policy_batches:
                    continue

                batch = samples.policy_batches[policy_id]
                for field in self.standardize_fields:
                    value = batch[field]
                    standardized = (value - value.mean()) / max(
                        1e-4, value.std())
                    batch[field] = standardized

                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    for minibatch in self._minibatches(batch):
                        batch_fetches = (
                            self.workers.local_worker().learn_on_batch(
                                MultiAgentBatch({
                                    policy_id: minibatch
                                }, minibatch.count)))[policy_id]
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.grad_timer.push_units_processed(samples.count)
        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return self.learner_stats

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _minibatches(self, samples):
        if not self.sgd_minibatch_size:
            yield samples
            return

        if isinstance(samples, MultiAgentBatch):
            raise NotImplementedError(
                "Minibatching not implemented for multi-agent in simple mode")

        if "state_in_0" in samples.data:
            logger.warning("Not shuffling RNN data for SGD in simple mode")
        else:
            samples.shuffle()

        i = 0
        slices = []
        while i < samples.count:
            slices.append((i, i + self.sgd_minibatch_size))
            i += self.sgd_minibatch_size
        random.shuffle(slices)

        for i, j in slices:
            yield samples.slice(i, j)
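The slicing scheme in _minibatches above can be checked in isolation: fixed-size windows laid over the batch, visited in shuffled order. A standalone sketch (not part of the original source; SampleBatch.slice is assumed to clamp a final window that runs past the end):

import random

def minibatch_slices(count, size):
    # Fixed-size windows [i, i + size) covering the batch, shuffled,
    # mirroring _minibatches above.
    slices = [(i, i + size) for i in range(0, count, size)]
    random.shuffle(slices)
    return slices

print(minibatch_slices(10, 4))  # e.g. [(8, 12), (0, 4), (4, 8)]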
Example #18
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""

    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=32,
              sample_batch_size=4,
              clip_rewards=True):

        self.replay_starts = learning_starts
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(
                    buffer_size,
                    alpha=prioritized_replay_alpha,
                    clip_rewards=clip_rewards)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size, clip_rewards)

        self.replay_buffers = collections.defaultdict(new_buffer)

        assert buffer_size >= self.replay_starts

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    if "weights" not in row:
                        row["weights"] = np.ones_like(row["rewards"])
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"], row["rewards"],
                        pack_if_needed(row["new_obs"]), row["dones"],
                        row["weights"])

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            info_dict = self.local_evaluator.compute_apply(samples)
            for policy_id, info in info_dict.items():
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (
                        np.abs(td_error) + self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample(
                         self.train_batch_size,
                         beta=self.prioritized_replay_beta)
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample(self.train_batch_size)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)

    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
            })
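The prioritization rule used in _optimize above reduces to p_i = |td_error_i| + eps; a minimal numpy sketch of the priority computation (the buffer call is shown only as a comment):

import numpy as np

td_error = np.array([0.5, -2.0, 0.0])
eps = 1e-6
# Strictly positive priorities, so no transition is starved of replay.
new_priorities = np.abs(td_error) + eps
# replay_buffer.update_priorities(batch_indexes, new_priorities)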
Example #19
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that policy evaluators return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""

    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              prioritized_replay=True,
              prioritized_replay_alpha=0.6,
              prioritized_replay_beta=0.4,
              schedule_max_timesteps=100000,
              beta_annealing_fraction=0.2,
              final_prioritized_replay_beta=0.4,
              prioritized_replay_eps=1e-6,
              train_batch_size=32,
              sample_batch_size=4):

        self.replay_starts = learning_starts
        # linearly annealing beta used in Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(
                schedule_max_timesteps * beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(
                    buffer_size, alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        assert buffer_size >= self.replay_starts

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (
                        np.abs(td_error) + self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample(
                         self.train_batch_size,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample(self.train_batch_size)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
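The LinearSchedule used for beta above interpolates from initial_p to final_p over schedule_timesteps and then stays at final_p; a standalone equivalent, assuming those semantics:

def beta_at(t, schedule_timesteps, initial_p, final_p):
    # Linear interpolation, clamped at final_p once t passes the horizon.
    frac = min(float(t) / schedule_timesteps, 1.0)
    return initial_p + frac * (final_p - initial_p)

print(beta_at(0, 20000, 0.4, 1.0))      # 0.4
print(beta_at(10000, 20000, 0.4, 1.0))  # ~0.7
print(beta_at(30000, 20000, 0.4, 1.0))  # ~1.0 (clamped)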
Example #20
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(
        self,
        workers,
        learning_starts=1000,
        buffer_size=10000,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=1e-6,
        final_prioritized_replay_beta=0.4,
        train_batch_size=32,
        before_learn_on_batch=None,
        synchronize_sampling=False,
        prioritized_replay_beta_annealing_timesteps=100000 * 0.2,
    ):
        """Initialize an sync replay optimizer.

        Args:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            final_prioritized_replay_beta (float): Final value of beta.
            train_batch_size (int): size of batches to learn on
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
            prioritized_replay_beta_annealing_timesteps (int): The timestep at
                which PR-beta annealing should end.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts

        # Linearly annealing beta used in Rainbow paper, stopping at
        # `final_prioritized_replay_beta`.
        self.prioritized_replay_beta = PiecewiseSchedule(
            endpoints=[(0, prioritized_replay_beta),
                       (prioritized_replay_beta_annealing_timesteps,
                        final_prioritized_replay_beta)],
            outside_value=final_prioritized_replay_beta,
            framework=None)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)

    def save(self):
        # Dump every stored transition as one tab-separated line.
        with open(
                "/home/yunke/prl_proj/panda_ws/src/franka_cal_sim/python/replay_buffer.txt",
                "w") as f:
            for policy_id, replay_buffer in self.replay_buffers.items():
                for data in replay_buffer._storage:
                    obs_t, action, reward, obs_tp1, done, weight = data
                    obs_s = ','.join([str(v) for v in obs_t])
                    action = ','.join([str(v) for v in action])
                    obs_tp1 = ','.join([str(v) for v in obs_tp1])
                    f.write("%s\t%s\t%s\t%s\t%s\t%s\n" %
                            (obs_s, action, reward, obs_tp1, done, weight))

    def restore(self):
        obs, actions, rewards, next_obs, terminals, weights = \
            [], [], [], [], [], []
        with open(
                "/home/yunke/prl_proj/panda_ws/src/franka_cal_sim/python/replay_buffer.txt",
                "r") as f:
            for line in f:
                cols = line.strip().split('\t')
                obs.append(np.array([float(v) for v in cols[0].split(',')]))
                actions.append(
                    np.array([float(v) for v in cols[1].split(',')]))
                rewards.append(float(cols[2]))
                next_obs.append(
                    np.array([float(v) for v in cols[3].split(',')]))
                # bool("False") is truthy, so parse the flag explicitly.
                terminals.append(cols[4] == "True")
                weights.append(float(cols[5]))

        # Refill the default-policy buffer from the parsed rows (assumes the
        # buffer was saved for the single default policy).
        for i in range(len(obs)):
            self.replay_buffers[DEFAULT_POLICY_ID].add(
                pack_if_needed(obs[i]), actions[i], rewards[i],
                pack_if_needed(next_obs[i]), terminals[i],
                weight=weights[i])
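A construction sketch for this variant (not from the original source; `workers` is an assumed, already-built RLlib WorkerSet):

# Hypothetical driver: sample/learn, then persist the buffer as TSV.
optimizer = SyncReplayOptimizer(
    workers,
    learning_starts=1000,
    buffer_size=10000,
    prioritized_replay=True,
    train_batch_size=32)
for _ in range(500):
    optimizer.step()  # optimization kicks in after `learning_starts` steps
optimizer.save()      # writes one tab-separated row per stored transition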
Example #21
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(self,
                 workers,
                 config,
                 learning_starts=1000,
                 buffer_size=50000,
                 prioritized_replay=True,
                 prioritized_replay_alpha=0.6,
                 prioritized_replay_beta=0.4,
                 prioritized_replay_eps=1e-6,
                 schedule_max_timesteps=100000,
                 beta_annealing_fraction=0.2,
                 final_prioritized_replay_beta=0.4,
                 train_batch_size=32,
                 sample_batch_size=4,
                 before_learn_on_batch=None,
                 synchronize_sampling=False):
        """Initialize an sync replay optimizer.

        Arguments:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            schedule_max_timesteps (int): number of timesteps in the schedule
            beta_annealing_fraction (float): fraction of schedule to anneal
                beta over
            final_prioritized_replay_beta (float): final value of beta
            train_batch_size (int): size of batches to learn on
            sample_batch_size (int): size of batches to sample from workers
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts
        # linearly annealing beta used in Rainbow paper
        self.prioritized_replay_beta = LinearSchedule(
            schedule_timesteps=int(schedule_max_timesteps *
                                   beta_annealing_fraction),
            initial_p=prioritized_replay_beta,
            final_p=final_prioritized_replay_beta)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}
        # Attention Info
        self.traffic_light_node_dict = {}
        self.record_dir = '/home/skylark/PycharmRemote/Gamma-Reward-Perfect/record/' + config[
            "env_config"]["Name"]
        self.read_traffic_light_node_dict()
        self.tmp_dic = self.traffic_light_node_dict['intersection_1_1'][
            'inter_id_to_index']
        # -------------------------------------------
        # For comparing reward changes
        self.raw_reward_store = {}
        self.Reward_store = {}
        for inter_id in self.tmp_dic:
            self.raw_reward_store[inter_id] = []
            self.Reward_store[inter_id] = []
        # self.j_store = 0
        # ------------------------------
        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))
        # For Gamma Reward by Skylark
        self.memory_thres = config["env_config"]["memory_thres"]
        self.num_steps_presampled = 0
        self.gamma = 0.5
        self.index = 0
        self.punish_coeff = 1.5
        self.config = config
        # Set up replay buffer
        if prioritized_replay:

            def pre_new_buffer():
                return PrioritizedReplayBuffer(buffer_size + self.memory_thres,
                                               alpha=prioritized_replay_alpha)
        else:

            def pre_new_buffer():
                return ReplayBuffer(buffer_size + self.memory_thres)

        self.pre_replay_buffers = collections.defaultdict(pre_new_buffer)
        # ------------------------------------------

        # '''
        # For Attention Reward by Skylark
        # '''
        # sa_size = [(15, 8), (15, 8), (15, 8), (15, 8), (15, 8), (15, 8)]
        # critic_hidden_dim = 128
        # attend_heads = 4
        # q_lr = 0.01
        # self.attention = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
        #                                  attend_heads=attend_heads)
        # self.target_attention = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
        #                                         attend_heads=attend_heads)
        # hard_update(self.target_attention, self.attention)
        # self.attention_optimizer = Adam(self.attention.parameters(), lr=q_lr,
        #                                 weight_decay=1e-3)
        # self.niter = 0
        # ------------------------------------------------------------------

    @override(PolicyOptimizer)
    def step(self, attention_score_dic=None):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)
            # For Gamma Reward by LJJ (see the local history for the changes)
            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.pre_replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

            if self.num_steps_presampled >= self.memory_thres:
                self._preprocess(batch, attention_score_dic)

            self.num_steps_presampled += batch.count

        # -----------------------------------------------------------------------

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    td_error = info["td_error"]
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)

    def _preprocess(self, batch, attention_score_dic=None):
        """
        Self-defined function: For Gamma Reward Replay Buffer Amendment
        :param batch: SampleBatch class,
        :param attention_score_dic: For transferring Attention score calculated by target attention layers
        :return: return Amendatory Replay Buffer
        """
        global j_store
        for policy_id, s in batch.policy_batches.items():
            storage = list(self.pre_replay_buffers[policy_id]._storage)
            index = len(storage) - self.memory_thres - 1
            tmp_buffer = storage.copy()
            current_intersection = self.inter_num_2_id(
                policy_id_handle(policy_id))
            # For comparing the change of rewards
            # ------------------------------
            while index > self.index - 1:
                obs = storage[index][0]
                action = storage[index][1]
                reward = storage[index][2]
                new_obs = storage[index][3]
                done = storage[index][4]
                p_value = 0

                all_roads_path_2dlst = np.array(
                    self.config['env_config']['lane_phase_info']
                    [current_intersection]['phase_roadLink_mapping'][action +
                                                                     1])
                all_end_roads = self.config['env_config']['lane_phase_info'][
                    current_intersection]['end_lane']
                permitted_end_roads = np.unique([
                    all_roads_path_2dlst[lane_index, 1] for lane_index,
                    start_lane in enumerate(all_roads_path_2dlst[:, 0])
                    if start_lane[-1] != '2'
                ])
                dis_permitted_end_roads = list(
                    set(all_end_roads).difference(
                        set(list(permitted_end_roads))))

                # Take neighbors into account
                for other_policy_id, s in batch.policy_batches.items():
                    other_intersection = self.inter_num_2_id(
                        policy_id_handle(other_policy_id))
                    if other_policy_id != policy_id and other_intersection in \
                            self.traffic_light_node_dict[current_intersection]['neighbor_ENWS']:
                        other_storage = self.pre_replay_buffers[
                            other_policy_id]._storage
                        # For the corresponding lane of a neighbouring
                        # intersection, m_2 is the waiting count at step
                        # t + n and m_1 at step t; the relative change is
                        # (m_2 - m_1) / m_1.
                        road_index_dict = {
                            road: road_index
                            for road_index, road in
                            enumerate(self.config['env_config']['road_sort']
                                      [other_intersection])
                        }

                        # differential = np.max(
                        #     np.array(other_storage[index + self.memory_thres - 1:index + self.memory_thres])[:,
                        #     2]) / other_storage[index][2]
                        for road in road_index_dict.keys():
                            if road in all_end_roads:
                                if road in permitted_end_roads:
                                    I_a = -1
                                elif road in dis_permitted_end_roads:
                                    I_a = 0
                                else:
                                    # Fall back to a neutral lane weight so
                                    # that I_a is always bound.
                                    I_a = 0
                                    logger.warning(
                                        "road %s is neither permitted nor "
                                        "dis-permitted", road)
                                road_index = road_index_dict[road]

                                m_1 = np.array(
                                    other_storage[index])[0][road_index]
                                m_2 = np.mean([
                                    other_storage[index + self.memory_thres -
                                                  2][0][road_index],
                                    other_storage[index + self.memory_thres -
                                                  1][0][road_index]
                                ])
                                if m_2 - m_1 == 0 or m_1 == 0:
                                    differential = 0
                                else:
                                    differential = (m_2 - m_1) / m_1  # m_2 = 0, m_1 != 0 -> differential = -1
                                    if differential > 1:
                                        differential = 0

                                p_value += m_1 * np.tanh(differential) * I_a

                if self.config['env_config']['Gamma_Reward']:
                    p_reward = reward + self.gamma * p_value
                    # print('Reward: ' + str(Reward) + ',' + 'reward: ' + str(reward))
                    if p_reward <= -20:
                        p_reward = -20
                    # print(Reward)
                else:
                    p_reward = reward
                # For comparing reward changes
                # if 50 < j_store < 100:
                #     self.raw_reward_store[self.inter_num_2_id(policy_id_handle(policy_id))].append(reward)
                #     self.Reward_store[self.inter_num_2_id(policy_id_handle(policy_id))].append(Reward)

                # ------------------------------
                tmp_buffer[index] = list(storage[index])
                tmp_buffer[index][2] = p_reward
                index -= 1

            for i in range(self.index, len(tmp_buffer) - self.memory_thres):
                self.replay_buffers[policy_id].add(obs_t=tmp_buffer[i][0],
                                                   action=tmp_buffer[i][1],
                                                   reward=tmp_buffer[i][2],
                                                   obs_tp1=tmp_buffer[i][3],
                                                   done=tmp_buffer[i][4],
                                                   weight=None)

        # Reward MDP
        index = len(storage) - self.memory_thres - 1
        while index > self.index - 1:
            for policy_id, s in batch.policy_batches.items():
                current_intersection = self.inter_num_2_id(
                    policy_id_handle(policy_id))
                storage = list(self.replay_buffers[policy_id]._storage)
                p_reward = storage[index][2]
                sum_other_reward = 0
                for other_policy_id, s in batch.policy_batches.items():
                    other_intersection = self.inter_num_2_id(
                        policy_id_handle(other_policy_id))
                    if other_policy_id != policy_id and other_intersection in \
                            self.traffic_light_node_dict[current_intersection]['neighbor_ENWS']:
                        other_storage = self.replay_buffers[
                            other_policy_id]._storage
                        pre_other_storage = self.pre_replay_buffers[
                            other_policy_id]._storage
                        if index + self.memory_thres >= len(other_storage):
                            sum_other_reward = 0
                        else:
                            sum_other_reward += np.tanh(
                                other_storage[index + self.memory_thres][2] /
                                pre_other_storage[index +
                                                  self.memory_thres][2] -
                                self.punish_coeff)
                Reward = p_reward + self.gamma * sum_other_reward
                self.replay_buffers[policy_id]._storage[index] = list(
                    self.replay_buffers[policy_id]._storage[index])
                self.replay_buffers[policy_id]._storage[index][2] = Reward
                self.replay_buffers[policy_id]._storage[index] = tuple(
                    self.replay_buffers[policy_id]._storage[index])
            index -= 1

        j_store += 1
        self.index = len(storage) - self.memory_thres

        # if j_store == 100:
        #     print("Start recording the reward !!!!!!!!!!!")
        #     raw_reward_store_np = {}
        #     Reward_store_np = {}
        #     for inter_id in self.tmp_dic:
        #         raw_reward_store_np[inter_id] = np.array(self.raw_reward_store[inter_id])
        #         Reward_store_np[inter_id] = np.array(self.Reward_store[inter_id])
        #     raw_reward_store_pd = pd.DataFrame(dict((k, pd.Series(v)) for k, v in raw_reward_store_np.items()))
        #     Reward_store_pd = pd.DataFrame(dict((k, pd.Series(v)) for k, v in Reward_store_np.items()))
        #     raw_reward_store_pd.to_csv(os.path.join(self.record_dir, 'raw_reward_store_pd.csv'))
        #     Reward_store_pd.to_csv(os.path.join(self.record_dir, 'Reward_store_pd.csv'))

        # self.replay_buffers = storage[:self.index]

        self.num_steps_sampled = len(
            self.replay_buffers[policy_id]._storage)  # Any policy_id is OK

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

    def _sigmoid(self, x):
        return 1 / (1 + math.exp(-x))

    def read_traffic_light_node_dict(self):
        path_to_read = os.path.join(self.record_dir,
                                    'traffic_light_node_dict.conf')
        with open(path_to_read, 'r') as f:
            self.traffic_light_node_dict = eval(f.read())
            print("Read traffic_light_node_dict")

    def inter_num_2_id(self, num):
        return list(self.tmp_dic.keys())[list(
            self.tmp_dic.values()).index(num)]
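The neighbour term accumulated in _preprocess above is p = sum over affected end lanes of m_1 * tanh((m_2 - m_1) / m_1) * I_a; a toy evaluation with hypothetical waiting counts (not from the original source):

import numpy as np

m_1, m_2 = 4.0, 6.0               # waiting counts at steps t and t + n
I_a = -1                          # lane is in permitted_end_roads
differential = (m_2 - m_1) / m_1  # relative queue growth
p = m_1 * np.tanh(differential) * I_a
print(round(p, 3))                # -1.848: a growing permitted queue is punished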
Example #22
class SyncBatchReplayOptimizer(PolicyOptimizer):
    """Variant of the sync replay optimizer that replays entire batches.

    This enables RNN support. Does not currently support prioritization."""

    @override(PolicyOptimizer)
    def _init(self,
              learning_starts=1000,
              buffer_size=10000,
              train_batch_size=32):
        self.replay_starts = learning_starts
        self.max_buffer_size = buffer_size
        self.train_batch_size = train_batch_size
        assert self.max_buffer_size >= self.replay_starts

        # List of buffered sample batches
        self.replay_buffer = []
        self.buffer_size = 0

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batches = ray.get(
                    [e.sample.remote() for e in self.remote_evaluators])
            else:
                batches = [self.local_evaluator.sample()]

            # Handle everything as if multiagent
            tmp = []
            for batch in batches:
                if isinstance(batch, SampleBatch):
                    batch = MultiAgentBatch({
                        DEFAULT_POLICY_ID: batch
                    }, batch.count)
                tmp.append(batch)
            batches = tmp

            for batch in batches:
                self.replay_buffer.append(batch)
                self.num_steps_sampled += batch.count
                self.buffer_size += batch.count
                while self.buffer_size > self.max_buffer_size:
                    evicted = self.replay_buffer.pop(0)
                    self.buffer_size -= evicted.count

        if self.num_steps_sampled >= self.replay_starts:
            return self._optimize()
        else:
            return {}

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
        return info_dict
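Whole-batch replay in _optimize above keeps each stored SampleBatch intact (so RNN sequences are never split) and simply draws random batches until enough timesteps accumulate; a standalone sketch with plain counts standing in for batches (not from the original source):

import random

buffer = [30, 50, 20]   # per-batch timestep counts in the replay buffer
train_batch_size = 64

picked = [random.choice(buffer)]
while sum(picked) < train_batch_size:
    picked.append(random.choice(buffer))
print(picked)           # e.g. [20, 30, 50] -> concatenated before learning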
Example #23
class SyncReplayOptimizer(PolicyOptimizer):
    """Variant of the local sync optimizer that supports replay (for DQN).

    This optimizer requires that rollout workers return an additional
    "td_error" array in the info return of compute_gradients(). This error
    term will be used for sample prioritization."""
    def __init__(
        self,
        workers,
        learning_starts=1000,
        buffer_size=10000,
        prioritized_replay=True,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=1e-6,
        final_prioritized_replay_beta=0.4,
        train_batch_size=32,
        before_learn_on_batch=None,
        synchronize_sampling=False,
        prioritized_replay_beta_annealing_timesteps=100000 * 0.2,
    ):
        """Initialize a sync replay optimizer.

        Args:
            workers (WorkerSet): all workers
            learning_starts (int): wait until this many steps have been sampled
                before starting optimization.
            buffer_size (int): max size of the replay buffer
            prioritized_replay (bool): whether to enable prioritized replay
            prioritized_replay_alpha (float): replay alpha hyperparameter
            prioritized_replay_beta (float): replay beta hyperparameter
            prioritized_replay_eps (float): replay eps hyperparameter
            final_prioritized_replay_beta (float): Final value of beta.
            train_batch_size (int): size of batches to learn on
            before_learn_on_batch (function): callback to run before passing
                the sampled batch to learn on
            synchronize_sampling (bool): whether to sample the experiences for
                all policies with the same indices (used in MADDPG).
            prioritized_replay_beta_annealing_timesteps (int): The timestep at
                which PR-beta annealing should end.
        """
        PolicyOptimizer.__init__(self, workers)

        self.replay_starts = learning_starts

        # Linearly anneal beta as in the Rainbow paper, stopping at
        # `final_prioritized_replay_beta`.
        self.prioritized_replay_beta = PiecewiseSchedule(
            endpoints=[(0, prioritized_replay_beta),
                       (prioritized_replay_beta_annealing_timesteps,
                        final_prioritized_replay_beta)],
            outside_value=final_prioritized_replay_beta,
            framework=None)
        self.prioritized_replay_eps = prioritized_replay_eps
        self.train_batch_size = train_batch_size
        self.before_learn_on_batch = before_learn_on_batch
        self.synchronize_sampling = synchronize_sampling

        # Stats
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.learner_stats = {}

        # Set up replay buffer
        if prioritized_replay:

            def new_buffer():
                return PrioritizedReplayBuffer(buffer_size,
                                               alpha=prioritized_replay_alpha)
        else:

            def new_buffer():
                return ReplayBuffer(buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        if buffer_size < self.replay_starts:
            logger.warning("buffer_size={} < replay_starts={}".format(
                buffer_size, self.replay_starts))

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.workers.remote_workers():
                batch = SampleBatch.concat_samples(
                    ray_get_and_free([
                        e.sample.remote()
                        for e in self.workers.remote_workers()
                    ]))
            else:
                batch = self.workers.local_worker().sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

    def _optimize(self):
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # TODO(sven): This is currently structured differently for
                    #  torch/tf. Clean up these results/info dicts across
                    #  policies (note: fixing this in torch_policy.py will
                    #  break e.g. DDPPO!).
                    td_error = info.get("td_error",
                                        info["learner_stats"].get("td_error"))
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count

    def _replay(self):
        samples = {}
        idxes = None
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.synchronize_sampling:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(
                            self.train_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.train_batch_size)

                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample_with_idxes(
                         idxes,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample_with_idxes(idxes)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
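
The prioritized-replay path in _optimize and the beta schedule built in __init__ boil down to two small formulas: new priorities are |td_error| + eps, and beta is interpolated linearly toward its final value over the annealing horizon, then held constant. A plain-NumPy sketch of both, with defaults mirroring the constructor's (which, with initial and final beta both 0.4, effectively disable annealing; Rainbow itself anneals beta toward 1.0). The class uses RLlib's PiecewiseSchedule and PrioritizedReplayBuffer for this; the functions below are illustrative only.

import numpy as np


def new_priorities(td_error, eps=1e-6):
    # Large TD errors get replayed more often; eps keeps priorities nonzero.
    return np.abs(td_error) + eps


def annealed_beta(t, beta0=0.4, beta_final=0.4, horizon=20000):
    # Linear interpolation from beta0 to beta_final over `horizon` steps,
    # held constant afterwards (PiecewiseSchedule's outside_value).
    frac = min(max(t / horizon, 0.0), 1.0)
    return beta0 + frac * (beta_final - beta0)
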
class SyncPhasicOptimizer(PolicyOptimizer):
    def __init__(self,
                 workers,
                 num_sgd_iter=1,
                 train_batch_size=1,
                 sgd_minibatch_size=0,
                 standardize_fields=frozenset([]),
                 aux_loss_every_k=16,
                 aux_loss_num_sgd_iter=9,
                 aux_loss_start_after_num_steps=0):
        PolicyOptimizer.__init__(self, workers)

        self.update_weights_timer = TimerStat()
        self.standardize_fields = standardize_fields
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.sgd_minibatch_size = sgd_minibatch_size
        self.train_batch_size = train_batch_size
        self.learner_stats = {}
        self.policies = dict(
            self.workers.local_worker().foreach_trainable_policy(
                lambda p, i: (i, p)))
        logger.debug("Policies to train: {}".format(self.policies))

        self.aux_loss_every_k = aux_loss_every_k
        self.aux_loss_num_sgd_iter = aux_loss_num_sgd_iter
        self.aux_loss_start_after_num_steps = aux_loss_start_after_num_steps
        self.memory = []
        # The train batch size must be divisible by the sgd minibatch size
        # to make populating the policy logits simple.
        assert (sgd_minibatch_size > 0
                and train_batch_size % sgd_minibatch_size == 0), (
            f"train_batch_size: {train_batch_size}, "
            f"sgd_minibatch_size: {sgd_minibatch_size}")

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.workers.remote_workers():
                weights = ray.put(self.workers.local_worker().get_weights())
                for e in self.workers.remote_workers():
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.workers.remote_workers():
                    samples.extend(
                        ray_get_and_free([
                            e.sample.remote()
                            for e in self.workers.remote_workers()
                        ]))
                else:
                    samples.append(self.workers.local_worker().sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        # Hack: a "phase" column added to the batch selects the loss to
        # optimize. Zeros run the policy optimization phase; ones run the
        # aux phase. The column must be added before the policy SGD pass.
        samples["phase"] = np.zeros(samples.count)

        with self.grad_timer:
            fetches = do_minibatch_sgd(samples, self.policies,
                                       self.workers.local_worker(),
                                       self.num_sgd_iter,
                                       self.sgd_minibatch_size,
                                       self.standardize_fields)
        self.grad_timer.push_units_processed(samples.count)

        if len(fetches) == 1 and DEFAULT_POLICY_ID in fetches:
            self.learner_stats = fetches[DEFAULT_POLICY_ID]
        else:
            self.learner_stats = fetches

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count

        if self.num_steps_sampled > self.aux_loss_start_after_num_steps:
            # Add samples to the memory to be provided to the aux loss.
            self._remove_unnecessary_data(samples)
            self.memory.append(samples)

            # Optionally run the aux optimization.
            if len(self.memory) >= self.aux_loss_every_k:
                samples = SampleBatch.concat_samples(self.memory)
                self._add_policy_logits(samples)
                # Ones indicate aux phase.
                samples["phase"] = np.ones_like(samples["phase"])
                do_minibatch_sgd(samples, self.policies,
                                 self.workers.local_worker(),
                                 self.aux_loss_num_sgd_iter,
                                 self.sgd_minibatch_size, [])
                self.memory = []

        return self.learner_stats

    def _remove_unnecessary_data(self,
                                 samples,
                                 keys_to_keep=frozenset([
                                     SampleBatch.CUR_OBS,
                                     SampleBatch.PREV_ACTIONS,
                                     SampleBatch.PREV_REWARDS,
                                     "phase",
                                     Postprocessing.VALUE_TARGETS,
                                 ])):
        for key in list(samples.keys()):
            if key not in keys_to_keep:
                del samples.data[key]

    def _add_policy_logits(self, samples):
        with torch.no_grad():
            policy = self.policies["default_policy"]

            all_logits = []
            for start in range(0, samples.count, self.sgd_minibatch_size):
                end = start + self.sgd_minibatch_size

                batch = samples.slice(start, end)
                batch["is_training"] = False
                batch = policy._lazy_tensor_dict(batch)

                logits, _ = policy.model.from_batch(batch)
                all_logits.append(logits.detach().cpu().numpy())

            samples["pre_aux_logits"] = np.concatenate(all_logits)

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms":
                round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms":
                round(1000 * self.grad_timer.mean, 3),
                "update_time_ms":
                round(1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput":
                round(self.grad_timer.mean_throughput, 3),
                "sample_peak_throughput":
                round(self.sample_timer.mean_throughput, 3),
                "opt_samples":
                round(self.grad_timer.mean_units_processed, 3),
                "learner":
                self.learner_stats,
            })
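
The phasic cadence above reduces to: run policy SGD on every batch with a zero "phase" column, accumulate the batches in memory, and after aux_loss_every_k of them run one auxiliary pass over the concatenated memory with a ones column. A minimal sketch under those assumptions; run_sgd and concat are hypothetical stand-ins for do_minibatch_sgd and SampleBatch.concat_samples, and samples is a plain dict of NumPy columns.

import numpy as np


def concat(batches):
    # Stand-in for SampleBatch.concat_samples: column-wise concatenation.
    return {k: np.concatenate([b[k] for b in batches]) for k in batches[0]}


def phasic_step(samples, memory, run_sgd, aux_loss_every_k=16):
    # Policy phase: a zero "phase" column selects the policy loss.
    samples["phase"] = np.zeros(len(samples["obs"]))
    run_sgd(samples)
    memory.append(samples)
    # Aux phase: after k batches, a ones column selects the aux loss.
    if len(memory) >= aux_loss_every_k:
        joined = concat(memory)
        joined["phase"] = np.ones_like(joined["phase"])
        run_sgd(joined)
        memory.clear()
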
Example #25
class SyncSamplesOptimizer(PolicyOptimizer):
    """A simple synchronous RL optimizer.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, and then updates a local model. The updated
    model weights are then broadcast to all remote evaluators.
    """

    @override(PolicyOptimizer)
    def _init(self, num_sgd_iter=1, train_batch_size=1):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.train_batch_size = train_batch_size
        self.learner_stats = {}

    @override(PolicyOptimizer)
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.compute_apply(samples)
                if "stats" in fetches:
                    self.learner_stats = fetches["stats"]
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches

    @override(PolicyOptimizer)
    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms": round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms": round(1000 * self.grad_timer.mean, 3),
                "update_time_ms": round(1000 * self.update_weights_timer.mean,
                                        3),
                "opt_peak_throughput": round(self.grad_timer.mean_throughput,
                                             3),
                "sample_peak_throughput": round(
                    self.sample_timer.mean_throughput, 3),
                "opt_samples": round(self.grad_timer.mean_units_processed, 3),
                "learner": self.learner_stats,
            })

class ModelBased_SyncSamplesOptimizer(PolicyOptimizer):
    """A synchronous RL optimizer that trains through an environment model.

    In each step, this optimizer pulls samples from a number of remote
    evaluators, concatenates them, passes them through the environment model
    set via `set_env_model`, and then updates a local model on the processed
    samples. The updated model weights are then broadcast to all remote
    evaluators.
    """
    def _init(self, num_sgd_iter=1, train_batch_size=1):
        self.update_weights_timer = TimerStat()
        self.sample_timer = TimerStat()
        self.grad_timer = TimerStat()
        self.throughput = RunningStat()
        self.num_sgd_iter = num_sgd_iter
        self.train_batch_size = train_batch_size
        self.learner_stats = {}

    def set_env_model(self, env_model):
        # TODO: add validation check
        self.env_model = env_model

    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            new_samples = self.env_model.process(samples)
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.compute_apply(new_samples)
                if "stats" in fetches:
                    self.learner_stats = fetches["stats"]
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            # Count units from the original samples; the env model may
            # change the batch size.
            self.grad_timer.push_units_processed(len(samples["obs"]))

        self.num_steps_sampled += len(samples["obs"])
        self.num_steps_trained += len(samples["obs"])
        return fetches

    def stats(self):
        return dict(
            PolicyOptimizer.stats(self), **{
                "sample_time_ms":
                round(1000 * self.sample_timer.mean, 3),
                "grad_time_ms":
                round(1000 * self.grad_timer.mean, 3),
                "update_time_ms":
                round(1000 * self.update_weights_timer.mean, 3),
                "opt_peak_throughput":
                round(self.grad_timer.mean_throughput, 3),
                "sample_peak_throughput":
                round(self.sample_timer.mean_throughput, 3),
                "opt_samples":
                round(self.grad_timer.mean_units_processed, 3),
                "learner":
                self.learner_stats,
            })
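
Note that the optimizer above only assumes env_model exposes a process(samples) method returning a batch with the same column layout. A hypothetical minimal model satisfying that contract, for illustration only; a real model would roll out imagined transitions from learned dynamics and append them to (or replace) the real samples.

class IdentityEnvModel:
    """Hypothetical env model stub: returns the sampled batch unchanged."""

    def process(self, samples):
        return samples

# Usage (construction of the optimizer itself elided):
#     optimizer.set_env_model(IdentityEnvModel())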