Ejemplo n.º 1
0
 def __call__(self, samples: SampleBatchType):
     """Compute gradients for ``samples`` on the local worker.

     The computation is timed under ``COMPUTE_GRADS_TIMER`` and the learner
     stats extracted from the returned info are stored under
     ``LEARNER_INFO`` in the metrics from ``LocalIterator.get_metrics()``.

     Returns:
         A ``(gradients, samples.count)`` tuple.
     """
     _check_sample_batch_type(samples)
     shared_metrics = LocalIterator.get_metrics()
     grads_timer = shared_metrics.timers[COMPUTE_GRADS_TIMER]
     with grads_timer:
         local_worker = self.workers.local_worker()
         gradients, grad_info = local_worker.compute_gradients(samples)
     shared_metrics.info[LEARNER_INFO] = get_learner_stats(grad_info)
     return gradients, samples.count
Ejemplo n.º 2
0
 def step(self):
     """Take one item off the in-queue and, if it holds a batch, learn on it.

     The in-queue yields ``(ra, replay)`` pairs. When ``replay`` is not
     ``None`` the local worker learns on it, the per-policy learner stats
     are stored in ``self.stats`` and ``(ra, prio_dict, replay.count)`` is
     put on the out-queue, where ``prio_dict`` maps each policy id to a
     ``(batch_indexes, td_error)`` pair. The in-queue size is recorded and
     ``self.weights_updated`` is set on every call.
     """
     with self.overall_timer:
         with self.queue_timer:
             ra, replay = self.inqueue.get()
         if replay is not None:
             prio_dict = {}
             with self.grad_timer:
                 grad_out = self.local_worker.learn_on_batch(replay)
                 for pid, info in grad_out.items():
                     # Only consult the nested learner stats when there is
                     # no top-level "td_error". Passing the nested lookup
                     # as a `.get()` default evaluates it eagerly and
                     # raises KeyError when LEARNER_STATS_KEY is missing,
                     # even though "td_error" itself is present.
                     if "td_error" in info:
                         td_error = info["td_error"]
                     else:
                         td_error = info[LEARNER_STATS_KEY].get("td_error")
                     # Switch off auto-conversion from numpy to torch/tf
                     # tensors for the indices. This may lead to errors
                     # when sent to the buffer for processing
                     # (may get manipulated if they are part of a tensor).
                     replay.policy_batches[pid].set_get_interceptor(None)
                     prio_dict[pid] = (
                         replay.policy_batches[pid].get("batch_indexes"),
                         td_error)
                     self.stats[pid] = get_learner_stats(info)
                 self.grad_timer.push_units_processed(replay.count)
             self.outqueue.put((ra, prio_dict, replay.count))
         self.learner_queue_size.push(self.inqueue.qsize())
         self.weights_updated = True
         # Conditional expression instead of the `x and y or 0` idiom; the
         # test mirrors the `replay is not None` guard used above.
         units_processed = replay.count if replay is not None else 0
         self.overall_timer.push_units_processed(units_processed)
    def _optimize(self):
        """Replay a batch, learn on it and refresh prioritized-replay weights.

        The replayed samples are optionally passed through
        ``self.before_learn_on_batch`` and then handed to the local worker's
        ``learn_on_batch``. For every policy in the result the learner stats
        are stored in ``self.learner_stats``; if that policy's replay buffer
        is a ``PrioritizedReplayBuffer`` its priorities are updated to
        ``abs(td_error) + self.prioritized_replay_eps``. Finally
        ``self.num_steps_trained`` grows by ``samples.count``.
        """
        samples = self._replay()

        with self.grad_timer:
            if self.before_learn_on_batch:
                samples = self.before_learn_on_batch(
                    samples,
                    self.workers.local_worker().policy_map,
                    self.train_batch_size)
            info_dict = self.workers.local_worker().learn_on_batch(samples)
            for policy_id, info in info_dict.items():
                self.learner_stats[policy_id] = get_learner_stats(info)
                replay_buffer = self.replay_buffers[policy_id]
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    # TODO(sven): This is currently structured differently for
                    #  torch/tf. Clean up these results/info dicts across
                    #  policies (note: fixing this in torch_policy.py will
                    #  break e.g. DDPPO!).
                    # Resolve the nested fallback lazily: as a `.get()`
                    # default it would be evaluated eagerly and raise
                    # KeyError when "learner_stats" is missing, even though
                    # a top-level "td_error" is available.
                    if "td_error" in info:
                        td_error = info["td_error"]
                    else:
                        td_error = info["learner_stats"].get("td_error")
                    new_priorities = (np.abs(td_error) +
                                      self.prioritized_replay_eps)
                    replay_buffer.update_priorities(
                        samples.policy_batches[policy_id]["batch_indexes"],
                        new_priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count
Ejemplo n.º 4
0
 def __call__(self,
              batch: SampleBatchType) -> (SampleBatchType, List[dict]):
     """Train the local worker on ``batch`` and push weights to remotes.

     Uses ``do_minibatch_sgd`` when ``self.num_sgd_iter > 1`` or
     ``self.sgd_minibatch_size > 0``; otherwise calls ``learn_on_batch``
     directly. Learner info, the trained-steps counter and the timers are
     recorded in the shared metrics. If there are remote workers, the local
     weights for ``self.policies`` are sent to each of them together with
     the current global vars.

     Returns:
         The ``(batch, info)`` pair, where ``info`` is the training result.
     """
     _check_sample_batch_type(batch)
     shared_metrics = _get_shared_metrics()
     learn_timer = shared_metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         run_minibatch_sgd = (self.num_sgd_iter > 1
                              or self.sgd_minibatch_size > 0)
         if not run_minibatch_sgd:
             info = self.workers.local_worker().learn_on_batch(batch)
             shared_metrics.info[LEARNER_INFO] = get_learner_stats(info)
         else:
             local_worker = self.workers.local_worker()
             policy_map = {
                 pid: local_worker.get_policy(pid)
                 for pid in self.policies
             }
             info = do_minibatch_sgd(batch, policy_map, local_worker,
                                     self.num_sgd_iter,
                                     self.sgd_minibatch_size, [])
             # TODO(ekl) shouldn't be returning learner stats directly here
             shared_metrics.info[LEARNER_INFO] = info
         learn_timer.push_units_processed(batch.count)
     shared_metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     if self.workers.remote_workers():
         with shared_metrics.timers[WORKER_UPDATE_TIMER]:
             local_weights = self.workers.local_worker().get_weights(
                 self.policies)
             weights_ref = ray.put(local_weights)
             for remote_worker in self.workers.remote_workers():
                 remote_worker.set_weights.remote(weights_ref,
                                                  _get_global_vars())
     # Also update global vars of the local worker.
     self.workers.local_worker().set_global_vars(_get_global_vars())
     return batch, info
Ejemplo n.º 5
0
    def improve_policy(self, num_improvements: int) -> Dict[str, float]:
        """Run policy improvement steps on a mix of real and virtual data.

        Each step samples ``int(train_batch_size * real_data_ratio)``
        transitions from ``self.replay`` and the remainder of the batch from
        ``self.virtual_replay``, concatenates them and calls
        ``policy.learn_on_batch``.

        Args:
            num_improvements: Number of times to call `policy.learn_on_batch`

        Returns:
            The learner stats of the last step (an empty dict if no step
            ran), updated with the policy's exploration info.
        """
        policy = self.get_policy()
        total_size = self.config["train_batch_size"]
        real_size = int(total_size * self.config["real_data_ratio"])
        virtual_size = total_size - real_size

        stats = {}
        for _ in range(num_improvements):
            parts = []
            if real_size:
                parts.append(self.replay.sample(real_size))
            if virtual_size:
                parts.append(self.virtual_replay.sample(virtual_size))
            batch = SampleBatch.concat_samples(parts)
            learner_info = policy.learn_on_batch(batch)
            stats = get_learner_stats(learner_info)
            self.tracker.num_steps_trained += batch.count

        exploration_info = policy.get_exploration_info()
        stats.update(exploration_info)
        return stats
Ejemplo n.º 6
0
    def step(self):
        """Sync weights, gather a training batch and run SGD passes on it.

        1. Under ``update_weights_timer``: send the local evaluator's
           weights to every remote evaluator (if any).
        2. Under ``sample_timer``: sample (remotely when remote evaluators
           exist, locally otherwise) until the collected count reaches
           ``self.train_batch_size``, then concatenate.
        3. Under ``grad_timer``: call ``learn_on_batch`` on the full batch
           ``self.num_sgd_iter`` times, keeping the latest learner stats.

        Returns:
            The result of the last ``learn_on_batch`` call.
        """
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights_ref = ray.put(self.local_evaluator.get_weights())
                for evaluator in self.remote_evaluators:
                    evaluator.set_weights.remote(weights_ref)

        with self.sample_timer:
            collected = []
            while sum(b.count for b in collected) < self.train_batch_size:
                if not self.remote_evaluators:
                    collected.append(self.local_evaluator.sample())
                    continue
                pending = [
                    evaluator.sample.remote()
                    for evaluator in self.remote_evaluators
                ]
                collected.extend(ray.get(pending))
            samples = SampleBatch.concat_samples(collected)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for sgd_iter in range(self.num_sgd_iter):
                fetches = self.local_evaluator.learn_on_batch(samples)
                self.learner_stats = get_learner_stats(fetches)
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(sgd_iter, fetches))
            self.grad_timer.push_units_processed(samples.count)

        # Every sampled step is also trained on in this optimizer.
        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches
Ejemplo n.º 7
0
    def __call__(self, data_tuple):
        """Run the meta-update and return the collected metrics.

        Args:
            data_tuple: Indexable whose element 0 holds the samples to learn
                on and whose element 1 is a dict merged into the returned
                metrics.

        Returns:
            The result of ``self.metric_gen`` updated with
            ``data_tuple[1]``.
        """
        # Metaupdate Step
        samples, adapt_metrics_dict = data_tuple[0], data_tuple[1]
        for _ in range(self.maml_optimizer_steps):
            fetches = self.workers.local_worker().learn_on_batch(samples)
        learner_stats = get_learner_stats(fetches)

        # Sync workers with meta policy
        self.workers.sync_weights()

        # Set worker tasks
        set_worker_tasks(self.workers)

        # Update KLS
        def update(pi, pi_id):
            assert "inner_kl" not in learner_stats, (
                "inner_kl should be nested under policy id key",
                learner_stats)
            if pi_id not in learner_stats:
                logger.warning("No data for {}, not updating kl".format(pi_id))
                return
            assert "inner_kl" in learner_stats[pi_id], (learner_stats, pi_id)
            pi.update_kls(learner_stats[pi_id]["inner_kl"])

        self.workers.local_worker().foreach_trainable_policy(update)

        # Modify Reporting Metrics
        shared_metrics = _get_shared_metrics()
        shared_metrics.info[LEARNER_INFO] = learner_stats
        shared_metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

        result = self.metric_gen.__call__(None)
        result.update(adapt_metrics_dict)

        return result
Ejemplo n.º 8
0
 def __call__(self, item):
     """Record metrics for one fetched gradient and pass it on.

     Args:
         item: A ``((grads, info), count)`` structure.

     Returns:
         The ``(grads, count)`` pair.
     """
     grads_and_info, count = item
     grads, info = grads_and_info
     shared_metrics = LocalIterator.get_metrics()
     shared_metrics.counters[STEPS_SAMPLED_COUNTER] += count
     shared_metrics.info[LEARNER_INFO] = get_learner_stats(info)
     # Time elapsed since `self.fetch_start_time`.
     wait_timer = shared_metrics.timers[GRAD_WAIT_TIMER]
     wait_timer.push(time.perf_counter() - self.fetch_start_time)
     return grads, count
Ejemplo n.º 9
0
def test_update_interval(policy, model_update_interval, samples):
    """Check that ``model_epochs`` only grows on the expected learn calls.

    The expected value increases on the first call and on every call whose
    1-based index is a multiple of ``model_update_interval``.
    """
    expected_epochs = 0
    total_calls = model_update_interval * 10

    for call_idx in range(1, total_calls + 1):
        learner_stats = get_learner_stats(policy.learn_on_batch(samples))
        assert policy._learn_calls == call_idx
        if call_idx == 1 or call_idx % model_update_interval == 0:
            expected_epochs += 1
        assert learner_stats["model_epochs"] == expected_epochs
Ejemplo n.º 10
0
 def _sgd_step(self):
     """Learn on a batch assembled from random replay-buffer entries.

     Entries are drawn with ``random.choice`` (at least one) until their
     combined count reaches ``self.train_batch_size``; the concatenated
     batch is passed to the local worker's ``learn_on_batch``.

     Returns:
         The per-policy info dict returned by ``learn_on_batch``.
     """
     picked = []
     while True:
         picked.append(random.choice(self.replay_buffer))
         if sum(b.count for b in picked) >= self.train_batch_size:
             break
     samples = SampleBatch.concat_samples(picked)
     info_dict = self.workers.local_worker().learn_on_batch(samples)
     for pid, policy_info in info_dict.items():
         self.learner_stats[pid] = get_learner_stats(policy_info)
     self.num_steps_trained += samples.count
     return info_dict
Ejemplo n.º 11
0
    def step(self):
        """Learn on the next minibatch and report its size.

        Takes a batch from ``self.minibatch_buffer``, trains the local
        worker on it, stores the learner stats in ``self.stats``, puts
        ``batch.count`` on the out-queue and records the in-queue size.
        """
        with self.queue_timer:
            batch, _unused = self.minibatch_buffer.get()

        with self.grad_timer:
            learn_result = self.local_worker.learn_on_batch(batch)
            self.weights_updated = True
            self.stats = get_learner_stats(learn_result)

        trained_count = batch.count
        self.outqueue.put(trained_count)
        pending_items = self.inqueue.qsize()
        self.learner_queue_size.push(pending_items)
Ejemplo n.º 12
0
 def _optimize(self):
     """Learn on a batch assembled from random replay-buffer entries.

     Entries are drawn with ``random.choice`` (at least one) until their
     combined count reaches ``self.train_batch_size``. Learning on the
     concatenated batch is timed under ``self.grad_timer``.

     Returns:
         The per-policy info dict returned by ``learn_on_batch``.
     """
     picked = []
     while True:
         picked.append(random.choice(self.replay_buffer))
         if sum(b.count for b in picked) >= self.train_batch_size:
             break
     samples = SampleBatch.concat_samples(picked)
     with self.grad_timer:
         info_dict = self.local_evaluator.learn_on_batch(samples)
         for pid, policy_info in info_dict.items():
             self.learner_stats[pid] = get_learner_stats(policy_info)
         self.grad_timer.push_units_processed(samples.count)
     self.num_steps_trained += samples.count
     return info_dict
Ejemplo n.º 13
0
 def step(self):
     """Learn on the next minibatch and update the running stats.

     Besides merging the learner stats into ``self.stats``, this adds
     ``batch.count`` to ``self.stats["train_timesteps"]``, increments
     ``self.num_steps`` and mirrors it in ``self.stats["update_steps"]``.
     ``batch.count`` is then put on the out-queue.
     """
     with self.queue_timer:
         batch, _unused = self.minibatch_buffer.get()
     with self.grad_timer:
         learn_result = self.local_worker.learn_on_batch(batch)
         self.weights_updated = True
         learner_stats = get_learner_stats(learn_result)
         self.stats.update(learner_stats)
         self.stats["train_timesteps"] += batch.count
         self.num_steps += 1
         self.stats["update_steps"] = self.num_steps
     self.outqueue.put(batch.count)
     pending_items = self.inqueue.qsize()
     self.learner_queue_size.push(pending_items)
Ejemplo n.º 14
0
    def step(self):
        """Run one optimization step on the next loaded optimizer.

        Asserts the loader thread is alive, takes an ``(opt, released)``
        pair from ``self.minibatch_buffer``, calls
        ``opt.optimize(self.sess, 0)`` and stores the learner stats. A
        released optimizer is put on ``self.idle_optimizers``.
        """
        assert self.loader_thread.is_alive()
        with self.load_wait_timer:
            opt, released = self.minibatch_buffer.get()

        with self.grad_timer:
            optimize_result = opt.optimize(self.sess, 0)
            self.weights_updated = True
            self.stats = get_learner_stats(optimize_result)

        if released:
            self.idle_optimizers.put(opt)

        tuples_loaded = opt.num_tuples_loaded
        self.outqueue.put(tuples_loaded)
        pending_items = self.inqueue.qsize()
        self.learner_queue_size.push(pending_items)
Ejemplo n.º 15
0
 def __call__(self, batch: SampleBatchType) -> List[dict]:
     """Train the local worker on ``batch`` and push weights to remotes.

     Records the learn timer, the trained-steps counter and the learner
     stats in the metrics from ``LocalIterator.get_metrics()``.

     Returns:
         The value returned by the local worker's ``learn_on_batch``.
     """
     _check_sample_batch_type(batch)
     shared_metrics = LocalIterator.get_metrics()
     learn_timer = shared_metrics.timers[LEARN_ON_BATCH_TIMER]
     with learn_timer:
         info = self.workers.local_worker().learn_on_batch(batch)
         learn_timer.push_units_processed(batch.count)
     shared_metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
     shared_metrics.info[LEARNER_INFO] = get_learner_stats(info)
     if not self.workers.remote_workers():
         return info
     with shared_metrics.timers[WORKER_UPDATE_TIMER]:
         local_weights = self.workers.local_worker().get_weights()
         weights_ref = ray.put(local_weights)
         for remote_worker in self.workers.remote_workers():
             remote_worker.set_weights.remote(weights_ref)
     return info
Ejemplo n.º 16
0
    def step(self) -> None:
        """Learn on the next minibatch, if one is available.

        Returns immediately when ``self.minibatch_buffer.get()`` raises
        ``queue.Empty``. Otherwise trains the local worker, stores the
        learner stats, increments ``self.num_steps`` and puts
        ``(batch.count, self.stats)`` on the out-queue.
        """
        with self.queue_timer:
            try:
                batch, _unused = self.minibatch_buffer.get()
            except queue.Empty:
                return

        with self.grad_timer:
            learn_result = self.local_worker.learn_on_batch(batch)
            self.weights_updated = True
            self.stats = get_learner_stats(learn_result)

        self.num_steps += 1
        report = (batch.count, self.stats)
        self.outqueue.put(report)
        pending_items = self.inqueue.qsize()
        self.learner_queue_size.push(pending_items)
Ejemplo n.º 17
0
 def step(self):
     """Take one item off the in-queue and, if it holds a batch, learn on it.

     For a non-``None`` batch, ``(ra, prio_dict, replay.count)`` is put on
     the out-queue, with ``prio_dict`` mapping each policy id to a
     ``(batch_indexes, td_error)`` pair. The in-queue size is recorded and
     ``self.weights_updated`` is set on every call.
     """
     with self.queue_timer:
         ra, replay = self.inqueue.get()
     if replay is not None:
         prio_dict = {}
         with self.grad_timer:
             grad_out = self.local_worker.learn_on_batch(replay)
             for pid, policy_info in grad_out.items():
                 policy_batch = replay.policy_batches[pid]
                 batch_indexes = policy_batch.data.get("batch_indexes")
                 td_error = policy_info.get("td_error")
                 prio_dict[pid] = (batch_indexes, td_error)
                 self.stats[pid] = get_learner_stats(policy_info)
         result = (ra, prio_dict, replay.count)
         self.outqueue.put(result)
     pending_items = self.inqueue.qsize()
     self.learner_queue_size.push(pending_items)
     self.weights_updated = True
Ejemplo n.º 18
0
    def step(self) -> Optional[_NextValueNotReady]:
        """Learn on the next minibatch, if one is available.

        When ``self.minibatch_buffer.get()`` raises ``queue.Empty`` this
        sleeps for 1 ms and returns a ``_NextValueNotReady`` instance.
        Otherwise it trains the local worker, stores the learner stats,
        increments ``self.num_steps`` and puts ``(batch.count, self.stats)``
        on the out-queue (returning ``None``).
        """
        with self.queue_timer:
            try:
                batch, _unused = self.minibatch_buffer.get()
            except queue.Empty:
                time.sleep(0.001)
                return _NextValueNotReady()

        with self.grad_timer:
            learn_result = self.local_worker.learn_on_batch(batch)
            self.weights_updated = True
            self.stats = get_learner_stats(learn_result)

        self.num_steps += 1
        report = (batch.count, self.stats)
        self.outqueue.put(report)
        pending_items = self.inqueue.qsize()
        self.learner_queue_size.push(pending_items)
Ejemplo n.º 19
0
    def step(self) -> None:
        """Learn on the next ready buffer and report the result.

        Asserts the loader thread is alive, takes a
        ``(buffer_idx, released)`` pair from
        ``self.ready_tower_stacks_buffer`` and calls
        ``self.policy.learn_on_loaded_batch`` for that buffer at offset 0.
        The learner stats are stored under ``DEFAULT_POLICY_ID``. A
        released buffer index is put on ``self.idle_tower_stacks``.
        """
        assert self.loader_thread.is_alive()
        with self.load_wait_timer:
            buffer_idx, released = self.ready_tower_stacks_buffer.get()

        with self.grad_timer:
            learn_result = self.policy.learn_on_loaded_batch(
                offset=0, buffer_index=buffer_idx)
            self.weights_updated = True
            policy_stats = get_learner_stats(learn_result)
            self.stats = {DEFAULT_POLICY_ID: policy_stats}

        if released:
            self.idle_tower_stacks.put(buffer_idx)

        num_loaded = self.policy.get_num_samples_loaded_into_buffer(
            buffer_idx)
        self.outqueue.put((num_loaded, self.stats))
        pending_items = self.inqueue.qsize()
        self.learner_queue_size.push(pending_items)
Ejemplo n.º 20
0
    def _optimize(self):
        """Replay a batch, learn on it and refresh prioritized-replay weights.

        For every policy in the learn result the learner stats are stored in
        ``self.learner_stats``; if that policy's replay buffer is a
        ``PrioritizedReplayBuffer`` its priorities are set to
        ``abs(td_error) + self.prioritized_replay_eps``. Finally
        ``self.num_steps_trained`` grows by ``samples.count``.
        """
        samples = self._replay()

        with self.grad_timer:
            info_dict = self.local_evaluator.learn_on_batch(samples)
            for pid, policy_info in info_dict.items():
                self.learner_stats[pid] = get_learner_stats(policy_info)
                buffer = self.replay_buffers[pid]
                if not isinstance(buffer, PrioritizedReplayBuffer):
                    continue
                abs_td_error = np.abs(policy_info["td_error"])
                priorities = abs_td_error + self.prioritized_replay_eps
                batch_indexes = samples.policy_batches[pid]["batch_indexes"]
                buffer.update_priorities(batch_indexes, priorities)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_trained += samples.count
Ejemplo n.º 21
0
    def _train(self):
        """Alternate sampling and replay-based learning for one iteration.

        Until ``self._iteration_done()`` is true: sample from the local
        worker, add every row to ``self.replay`` and then run one
        ``learn_on_batch`` call per sampled step on batches of
        ``config["train_batch_size"]`` drawn from the replay.

        Returns:
            The result of ``self._log_metrics`` on the latest stats.
        """
        start_samples = self.sample_until_learning_starts()

        local_worker = self.workers.local_worker()
        policy = local_worker.get_policy()
        stats = {}
        while not self._iteration_done():
            rollout = local_worker.sample()
            self.tracker.num_steps_sampled += rollout.count
            for transition in rollout.rows():
                self.replay.add(transition)
            exploration_info = policy.get_exploration_info()
            stats.update(exploration_info)

            self._before_replay_steps(policy)
            for _ in range(rollout.count):
                batch = self.replay.sample(self.config["train_batch_size"])
                learner_info = policy.learn_on_batch(batch)
                # NOTE(review): this rebinds `stats`, so the exploration
                # info merged above is replaced -- confirm that is intended.
                stats = get_learner_stats(learner_info)
                self.tracker.num_steps_trained += batch.count

        self.tracker.num_steps_sampled += start_samples
        return self._log_metrics(stats)
Ejemplo n.º 22
0
 def step(self):
     """Take one item off the in-queue and, if it holds a batch, learn on it.

     The in-queue yields ``(ra, replay)`` pairs. When ``replay`` is not
     ``None`` the local worker learns on it, the per-policy learner stats
     are stored in ``self.stats`` and ``(ra, prio_dict, replay.count)`` is
     put on the out-queue, where ``prio_dict`` maps each policy id to a
     ``(batch_indexes, td_error)`` pair. The in-queue size is recorded and
     ``self.weights_updated`` is set on every call.
     """
     with self.overall_timer:
         with self.queue_timer:
             ra, replay = self.inqueue.get()
         if replay is not None:
             prio_dict = {}
             with self.grad_timer:
                 grad_out = self.local_worker.learn_on_batch(replay)
                 for pid, info in grad_out.items():
                     # Only consult the nested learner stats when there is
                     # no top-level "td_error". Passing the nested lookup
                     # as a `.get()` default evaluates it eagerly and
                     # raises KeyError when LEARNER_STATS_KEY is missing,
                     # even though "td_error" itself is present.
                     if "td_error" in info:
                         td_error = info["td_error"]
                     else:
                         td_error = info[LEARNER_STATS_KEY].get("td_error")
                     prio_dict[pid] = (replay.policy_batches[pid].data.get(
                         "batch_indexes"), td_error)
                     self.stats[pid] = get_learner_stats(info)
                 self.grad_timer.push_units_processed(replay.count)
             self.outqueue.put((ra, prio_dict, replay.count))
         self.learner_queue_size.push(self.inqueue.qsize())
         self.weights_updated = True
         # Conditional expression instead of the `x and y or 0` idiom; the
         # test mirrors the `replay is not None` guard used above.
         units_processed = replay.count if replay is not None else 0
         self.overall_timer.push_units_processed(units_processed)
Ejemplo n.º 23
0
    def step(self):
        """Apply asynchronously computed gradients from the remote workers.

        Every remote worker first receives the local weights and is asked to
        sample and compute gradients. Each time a gradient future becomes
        ready it is fetched and (if not ``None``) applied on the local
        worker; the worker that produced it is re-dispatched with the
        current local weights as long as fewer than ``self.grads_per_step``
        gradient tasks have been launched.
        """
        initial_weights = ray.put(self.workers.local_worker().get_weights())
        pending_gradients = {}
        num_gradients = 0

        # Kick off the first wave of async tasks
        for worker in self.workers.remote_workers():
            worker.set_weights.remote(initial_weights)
            grad_future = worker.compute_gradients.remote(
                worker.sample.remote())
            pending_gradients[grad_future] = worker
            num_gradients += 1

        while pending_gradients:
            with self.wait_timer:
                ready_futures = ray.wait(
                    list(pending_gradients.keys()), num_returns=1)[0]
                done_future = ready_futures[0]

                gradient, info = ray_get_and_free(done_future)
                worker = pending_gradients.pop(done_future)
                self.learner_stats = get_learner_stats(info)

            if gradient is not None:
                with self.apply_timer:
                    self.workers.local_worker().apply_gradients(gradient)
                self.num_steps_sampled += info["batch_count"]
                self.num_steps_trained += info["batch_count"]

            if num_gradients >= self.grads_per_step:
                # Launch budget used up; just drain what is still pending.
                continue

            with self.dispatch_timer:
                worker.set_weights.remote(
                    self.workers.local_worker().get_weights())
                grad_future = worker.compute_gradients.remote(
                    worker.sample.remote())

                pending_gradients[grad_future] = worker
                num_gradients += 1
Ejemplo n.º 24
0
    def __call__(self, data_tuple):
        """Run one MAML meta-update and report metrics when a cycle ends.

        Args:
            data_tuple: 1st element is samples collected from MAML
                inner adaptation steps and 2nd element is accumulated
                metrics.

        Returns:
            A one-element list holding the collected metrics when
            ``self.step_counter == self.num_steps - 1`` (the counter is then
            reset to 0), otherwise an empty list (the counter is then
            incremented).
        """
        # Metaupdate Step.
        print("Meta-Update Step")
        samples = data_tuple[0]
        adapt_metrics_dict = data_tuple[1]
        # Metrics are post-processed under a per-step "MAMLIter<n>" prefix.
        self.postprocess_metrics(adapt_metrics_dict,
                                 prefix="MAMLIter{}".format(self.step_counter))

        # MAML Meta-update.
        # `fetches` stays None if `self.maml_optimizer_steps` is 0; otherwise
        # it holds the result of the last `learn_on_batch` call.
        fetches = None
        for i in range(self.maml_optimizer_steps):
            fetches = self.workers.local_worker().learn_on_batch(samples)
        learner_stats = get_learner_stats(fetches)

        # Update KLs.
        def update(pi, pi_id):
            # Stats are expected to be keyed by policy id, so "inner_kl"
            # must not appear at the top level.
            assert "inner_kl" not in learner_stats, (
                "inner_kl should be nested under policy id key",
                learner_stats,
            )
            if pi_id in learner_stats:
                assert "inner_kl" in learner_stats[pi_id], (learner_stats,
                                                            pi_id)
                pi.update_kls(learner_stats[pi_id]["inner_kl"])
            else:
                logger.warning("No data for {}, not updating kl".format(pi_id))

        self.workers.local_worker().foreach_policy_to_train(update)

        # Modify Reporting Metrics.
        metrics = _get_shared_metrics()
        # Note: the raw `fetches` (not `learner_stats`) is stored here.
        metrics.info[LEARNER_INFO] = fetches
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

        if self.step_counter == self.num_steps - 1:
            # Last MAML step of the cycle: apply `fit_dynamics` to the local
            # worker's policies and keep the first result.
            td_metric = self.workers.local_worker().foreach_policy(
                fit_dynamics)[0]

            # Sync workers with meta policy.
            self.workers.sync_weights()

            # Sync TD Models with workers.
            sync_ensemble(self.workers)
            sync_stats(self.workers)

            # The sampled-steps counter is overwritten (not incremented)
            # with the value reported by `fit_dynamics`.
            metrics.counters[STEPS_SAMPLED_COUNTER] = td_metric[
                STEPS_SAMPLED_COUNTER]

            # Modify to CollectMetrics.
            res = self.metric_gen.__call__(None)
            # NOTE(review): `self.metrics` is presumably filled by
            # `postprocess_metrics` above -- confirm.
            res.update(self.metrics)
            self.step_counter = 0
            print("MB-MPO Iteration Completed")
            return [res]
        else:
            print("MAML Iteration {} Completed".format(self.step_counter))
            self.step_counter += 1

            # Sync workers with meta policy
            print("Syncing Weights with Workers")
            self.workers.sync_weights()
            return []