def __call__(self, samples: SampleBatchType):
    _check_sample_batch_type(samples)
    metrics = LocalIterator.get_metrics()
    with metrics.timers[COMPUTE_GRADS_TIMER]:
        grad, info = self.workers.local_worker().compute_gradients(samples)
    metrics.info[LEARNER_INFO] = get_learner_stats(info)
    return grad, samples.count

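# Note: every snippet in this section funnels training results through
# `get_learner_stats`. A minimal sketch of that helper, assuming RLlib's
# convention of nesting per-policy stats under LEARNER_STATS_KEY (an
# illustrative reconstruction, not necessarily the exact library source):
LEARNER_STATS_KEY = "learner_stats"

def get_learner_stats(grad_info: dict) -> dict:
    # Unwrap the nested stats dict if present; otherwise pass the raw
    # fetch dict through unchanged.
    if LEARNER_STATS_KEY in grad_info:
        return grad_info[LEARNER_STATS_KEY]
    return grad_info
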
def step(self):
    with self.overall_timer:
        with self.queue_timer:
            ra, replay = self.inqueue.get()
        if replay is not None:
            prio_dict = {}
            with self.grad_timer:
                grad_out = self.local_worker.learn_on_batch(replay)
                for pid, info in grad_out.items():
                    td_error = info.get(
                        "td_error",
                        info[LEARNER_STATS_KEY].get("td_error"))
                    # Switch off auto-conversion from numpy to torch/tf
                    # tensors for the indices. This may lead to errors
                    # when sent to the buffer for processing
                    # (may get manipulated if they are part of a tensor).
                    replay.policy_batches[pid].set_get_interceptor(None)
                    prio_dict[pid] = (
                        replay.policy_batches[pid].get("batch_indexes"),
                        td_error)
                    self.stats[pid] = get_learner_stats(info)
                self.grad_timer.push_units_processed(replay.count)
            self.outqueue.put((ra, prio_dict, replay.count))
        self.learner_queue_size.push(self.inqueue.qsize())
        self.weights_updated = True
        self.overall_timer.push_units_processed(replay and replay.count or 0)

def _optimize(self):
    samples = self._replay()

    with self.grad_timer:
        if self.before_learn_on_batch:
            samples = self.before_learn_on_batch(
                samples,
                self.workers.local_worker().policy_map,
                self.train_batch_size)
        info_dict = self.workers.local_worker().learn_on_batch(samples)
        for policy_id, info in info_dict.items():
            self.learner_stats[policy_id] = get_learner_stats(info)
            replay_buffer = self.replay_buffers[policy_id]
            if isinstance(replay_buffer, PrioritizedReplayBuffer):
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get(
                    "td_error", info["learner_stats"].get("td_error"))
                new_priorities = (
                    np.abs(td_error) + self.prioritized_replay_eps)
                replay_buffer.update_priorities(
                    samples.policy_batches[policy_id]["batch_indexes"],
                    new_priorities)
        self.grad_timer.push_units_processed(samples.count)
    self.num_steps_trained += samples.count

def __call__(self,
             batch: SampleBatchType) -> Tuple[SampleBatchType, List[dict]]:
    _check_sample_batch_type(batch)
    metrics = _get_shared_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        if self.num_sgd_iter > 1 or self.sgd_minibatch_size > 0:
            w = self.workers.local_worker()
            info = do_minibatch_sgd(
                batch, {p: w.get_policy(p) for p in self.policies}, w,
                self.num_sgd_iter, self.sgd_minibatch_size, [])
            # TODO(ekl) shouldn't be returning learner stats directly here
            metrics.info[LEARNER_INFO] = info
        else:
            info = self.workers.local_worker().learn_on_batch(batch)
            metrics.info[LEARNER_INFO] = get_learner_stats(info)
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(
                self.workers.local_worker().get_weights(self.policies))
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights, _get_global_vars())
    # Also update global vars of the local worker.
    self.workers.local_worker().set_global_vars(_get_global_vars())
    return batch, info

def improve_policy(self, num_improvements: int) -> Dict[str, float]:
    """Call the policy to perform policy improvement using the augmented
    replay.

    Args:
        num_improvements: Number of times to call `policy.learn_on_batch`

    Returns:
        A dictionary of training and exploration statistics
    """
    policy = self.get_policy()
    batch_size = self.config["train_batch_size"]
    env_batch_size = int(batch_size * self.config["real_data_ratio"])
    model_batch_size = batch_size - env_batch_size
    # E.g., with train_batch_size=256 and real_data_ratio=0.25, each
    # improvement step mixes 64 real and 192 model-generated transitions.

    stats = {}
    for _ in range(num_improvements):
        samples = []
        if env_batch_size:
            samples += [self.replay.sample(env_batch_size)]
        if model_batch_size:
            samples += [self.virtual_replay.sample(model_batch_size)]
        batch = SampleBatch.concat_samples(samples)

        stats = get_learner_stats(policy.learn_on_batch(batch))
        self.tracker.num_steps_trained += batch.count

    stats.update(policy.get_exploration_info())
    return stats

def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray.get([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.learn_on_batch(samples)
            self.learner_stats = get_learner_stats(fetches)
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches

def __call__(self, data_tuple):
    # Metaupdate Step
    samples = data_tuple[0]
    adapt_metrics_dict = data_tuple[1]
    for i in range(self.maml_optimizer_steps):
        fetches = self.workers.local_worker().learn_on_batch(samples)
    fetches = get_learner_stats(fetches)

    # Sync workers with meta policy
    self.workers.sync_weights()

    # Set worker tasks
    set_worker_tasks(self.workers)

    # Update KLs
    def update(pi, pi_id):
        assert "inner_kl" not in fetches, (
            "inner_kl should be nested under policy id key", fetches)
        if pi_id in fetches:
            assert "inner_kl" in fetches[pi_id], (fetches, pi_id)
            pi.update_kls(fetches[pi_id]["inner_kl"])
        else:
            logger.warning("No data for {}, not updating kl".format(pi_id))

    self.workers.local_worker().foreach_trainable_policy(update)

    # Modify Reporting Metrics
    metrics = _get_shared_metrics()
    metrics.info[LEARNER_INFO] = fetches
    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

    res = self.metric_gen.__call__(None)
    res.update(adapt_metrics_dict)

    return res

def __call__(self, item):
    (grads, info), count = item
    metrics = LocalIterator.get_metrics()
    metrics.counters[STEPS_SAMPLED_COUNTER] += count
    metrics.info[LEARNER_INFO] = get_learner_stats(info)
    metrics.timers[GRAD_WAIT_TIMER].push(
        time.perf_counter() - self.fetch_start_time)
    return grads, count

def test_update_interval(policy, model_update_interval, samples):
    expected = 0
    for i in range(1, model_update_interval * 10 + 1):
        info = policy.learn_on_batch(samples)
        info = get_learner_stats(info)
        assert policy._learn_calls == i
        expected += 1 if (i % model_update_interval == 0 or i == 1) else 0
        assert info["model_epochs"] == expected

def _sgd_step(self):
    samples = [random.choice(self.replay_buffer)]
    while sum(s.count for s in samples) < self.train_batch_size:
        samples.append(random.choice(self.replay_buffer))
    samples = SampleBatch.concat_samples(samples)

    info_dict = self.workers.local_worker().learn_on_batch(samples)
    for policy_id, info in info_dict.items():
        self.learner_stats[policy_id] = get_learner_stats(info)
    self.num_steps_trained += samples.count
    return info_dict

def step(self):
    with self.queue_timer:
        batch, _ = self.minibatch_buffer.get()

    with self.grad_timer:
        fetches = self.local_worker.learn_on_batch(batch)
        self.weights_updated = True
        self.stats = get_learner_stats(fetches)

    self.outqueue.put(batch.count)
    self.learner_queue_size.push(self.inqueue.qsize())

def _optimize(self):
    samples = [random.choice(self.replay_buffer)]
    while sum(s.count for s in samples) < self.train_batch_size:
        samples.append(random.choice(self.replay_buffer))
    samples = SampleBatch.concat_samples(samples)

    with self.grad_timer:
        info_dict = self.local_evaluator.learn_on_batch(samples)
        for policy_id, info in info_dict.items():
            self.learner_stats[policy_id] = get_learner_stats(info)
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_trained += samples.count
    return info_dict

def step(self):
    with self.queue_timer:
        batch, _ = self.minibatch_buffer.get()

    with self.grad_timer:
        fetches = self.local_worker.learn_on_batch(batch)
        self.weights_updated = True
        self.stats.update(get_learner_stats(fetches))
        self.stats["train_timesteps"] += batch.count

    self.num_steps += 1
    self.stats["update_steps"] = self.num_steps
    self.outqueue.put(batch.count)
    self.learner_queue_size.push(self.inqueue.qsize())

def step(self):
    assert self.loader_thread.is_alive()
    with self.load_wait_timer:
        opt, released = self.minibatch_buffer.get()

    with self.grad_timer:
        fetches = opt.optimize(self.sess, 0)
        self.weights_updated = True
        self.stats = get_learner_stats(fetches)

    if released:
        self.idle_optimizers.put(opt)

    self.outqueue.put(opt.num_tuples_loaded)
    self.learner_queue_size.push(self.inqueue.qsize())

def __call__(self, batch: SampleBatchType) -> List[dict]:
    _check_sample_batch_type(batch)
    metrics = LocalIterator.get_metrics()
    learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        info = self.workers.local_worker().learn_on_batch(batch)
        learn_timer.push_units_processed(batch.count)
    metrics.counters[STEPS_TRAINED_COUNTER] += batch.count
    metrics.info[LEARNER_INFO] = get_learner_stats(info)
    if self.workers.remote_workers():
        with metrics.timers[WORKER_UPDATE_TIMER]:
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)
    return info

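# Usage sketch: callables like the one above are composed as stages of an
# RLlib execution plan. The wiring below follows RLlib's 1.x
# execution-plan API (`ParallelRollouts`, `ConcatBatches`, and
# `TrainOneStep` are real RLlib names); `example_execution_plan` itself
# is a hypothetical illustration, not library code:
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep

def example_execution_plan(workers, config):
    # Collect rollouts from all workers, concatenate them up to the
    # train batch size, then train on each batch with TrainOneStep.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    return rollouts \
        .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))
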
def step(self) -> None:
    with self.queue_timer:
        try:
            batch, _ = self.minibatch_buffer.get()
        except queue.Empty:
            return

    with self.grad_timer:
        fetches = self.local_worker.learn_on_batch(batch)
        self.weights_updated = True
        self.stats = get_learner_stats(fetches)

    self.num_steps += 1
    self.outqueue.put((batch.count, self.stats))
    self.learner_queue_size.push(self.inqueue.qsize())

def step(self):
    with self.queue_timer:
        ra, replay = self.inqueue.get()
    if replay is not None:
        prio_dict = {}
        with self.grad_timer:
            grad_out = self.local_worker.learn_on_batch(replay)
            for pid, info in grad_out.items():
                prio_dict[pid] = (
                    replay.policy_batches[pid].data.get("batch_indexes"),
                    info.get("td_error"))
                self.stats[pid] = get_learner_stats(info)
        self.outqueue.put((ra, prio_dict, replay.count))
    self.learner_queue_size.push(self.inqueue.qsize())
    self.weights_updated = True

def step(self) -> Optional[_NextValueNotReady]:
    with self.queue_timer:
        try:
            batch, _ = self.minibatch_buffer.get()
        except queue.Empty:
            time.sleep(0.001)
            return _NextValueNotReady()

    with self.grad_timer:
        fetches = self.local_worker.learn_on_batch(batch)
        self.weights_updated = True
        self.stats = get_learner_stats(fetches)

    self.num_steps += 1
    self.outqueue.put((batch.count, self.stats))
    self.learner_queue_size.push(self.inqueue.qsize())

def step(self) -> None:
    assert self.loader_thread.is_alive()
    with self.load_wait_timer:
        buffer_idx, released = self.ready_tower_stacks_buffer.get()

    with self.grad_timer:
        fetches = self.policy.learn_on_loaded_batch(
            offset=0, buffer_index=buffer_idx)
        self.weights_updated = True
        self.stats = {DEFAULT_POLICY_ID: get_learner_stats(fetches)}

    if released:
        self.idle_tower_stacks.put(buffer_idx)

    self.outqueue.put(
        (self.policy.get_num_samples_loaded_into_buffer(buffer_idx),
         self.stats))
    self.learner_queue_size.push(self.inqueue.qsize())

def _optimize(self):
    samples = self._replay()

    with self.grad_timer:
        info_dict = self.local_evaluator.learn_on_batch(samples)
        for policy_id, info in info_dict.items():
            self.learner_stats[policy_id] = get_learner_stats(info)
            replay_buffer = self.replay_buffers[policy_id]
            if isinstance(replay_buffer, PrioritizedReplayBuffer):
                td_error = info["td_error"]
                new_priorities = (
                    np.abs(td_error) + self.prioritized_replay_eps)
                replay_buffer.update_priorities(
                    samples.policy_batches[policy_id]["batch_indexes"],
                    new_priorities)
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_trained += samples.count

def _train(self):
    start_samples = self.sample_until_learning_starts()

    worker = self.workers.local_worker()
    policy = worker.get_policy()
    stats = {}
    while not self._iteration_done():
        samples = worker.sample()
        self.tracker.num_steps_sampled += samples.count
        for row in samples.rows():
            self.replay.add(row)
        stats.update(policy.get_exploration_info())

        self._before_replay_steps(policy)
        for _ in range(samples.count):
            batch = self.replay.sample(self.config["train_batch_size"])
            stats = get_learner_stats(policy.learn_on_batch(batch))
            self.tracker.num_steps_trained += batch.count

    self.tracker.num_steps_sampled += start_samples
    return self._log_metrics(stats)

def step(self):
    with self.overall_timer:
        with self.queue_timer:
            ra, replay = self.inqueue.get()
        if replay is not None:
            prio_dict = {}
            with self.grad_timer:
                grad_out = self.local_worker.learn_on_batch(replay)
                for pid, info in grad_out.items():
                    td_error = info.get(
                        "td_error",
                        info[LEARNER_STATS_KEY].get("td_error"))
                    prio_dict[pid] = (replay.policy_batches[pid].data.get(
                        "batch_indexes"), td_error)
                    self.stats[pid] = get_learner_stats(info)
                self.grad_timer.push_units_processed(replay.count)
            self.outqueue.put((ra, prio_dict, replay.count))
        self.learner_queue_size.push(self.inqueue.qsize())
        self.weights_updated = True
        self.overall_timer.push_units_processed(replay and replay.count or 0)

def step(self):
    weights = ray.put(self.workers.local_worker().get_weights())
    pending_gradients = {}
    num_gradients = 0

    # Kick off the first wave of async tasks
    for e in self.workers.remote_workers():
        e.set_weights.remote(weights)
        future = e.compute_gradients.remote(e.sample.remote())
        pending_gradients[future] = e
        num_gradients += 1

    while pending_gradients:
        with self.wait_timer:
            wait_results = ray.wait(
                list(pending_gradients.keys()), num_returns=1)
            ready_list = wait_results[0]
            future = ready_list[0]

            gradient, info = ray_get_and_free(future)
            e = pending_gradients.pop(future)
            self.learner_stats = get_learner_stats(info)

        if gradient is not None:
            with self.apply_timer:
                self.workers.local_worker().apply_gradients(gradient)
            self.num_steps_sampled += info["batch_count"]
            self.num_steps_trained += info["batch_count"]

        if num_gradients < self.grads_per_step:
            with self.dispatch_timer:
                e.set_weights.remote(
                    self.workers.local_worker().get_weights())
                future = e.compute_gradients.remote(e.sample.remote())
                pending_gradients[future] = e
                num_gradients += 1

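# The loop above is plain asynchronous dispatch built on `ray.wait`. A
# stripped-down, runnable sketch of the same pattern with ordinary Ray
# tasks (all names below are illustrative, not RLlib APIs):
import ray

@ray.remote
def square(x):
    return x * x

ray.init(ignore_reinit_error=True)
pending = {square.remote(i): i for i in range(4)}
while pending:
    # Block until exactly one in-flight task finishes.
    [ready], _ = ray.wait(list(pending.keys()), num_returns=1)
    result = ray.get(ready)
    pending.pop(ready)
    # In the optimizer above, this is where the gradient would be applied
    # and a fresh task dispatched to the now-idle worker.
    print(result)
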
def __call__(self, data_tuple):
    """Args:
        data_tuple: 1st element is samples collected from MAML
            Inner adaptation steps and 2nd element is accumulated metrics
    """
    # Metaupdate Step.
    print("Meta-Update Step")
    samples = data_tuple[0]
    adapt_metrics_dict = data_tuple[1]
    self.postprocess_metrics(
        adapt_metrics_dict, prefix="MAMLIter{}".format(self.step_counter))

    # MAML Meta-update.
    fetches = None
    for i in range(self.maml_optimizer_steps):
        fetches = self.workers.local_worker().learn_on_batch(samples)
    learner_stats = get_learner_stats(fetches)

    # Update KLs.
    def update(pi, pi_id):
        assert "inner_kl" not in learner_stats, (
            "inner_kl should be nested under policy id key",
            learner_stats,
        )
        if pi_id in learner_stats:
            assert "inner_kl" in learner_stats[pi_id], (learner_stats, pi_id)
            pi.update_kls(learner_stats[pi_id]["inner_kl"])
        else:
            logger.warning("No data for {}, not updating kl".format(pi_id))

    self.workers.local_worker().foreach_policy_to_train(update)

    # Modify Reporting Metrics.
    metrics = _get_shared_metrics()
    metrics.info[LEARNER_INFO] = fetches
    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
    metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

    if self.step_counter == self.num_steps - 1:
        td_metric = self.workers.local_worker().foreach_policy(
            fit_dynamics)[0]

        # Sync workers with meta policy.
        self.workers.sync_weights()

        # Sync TD Models with workers.
        sync_ensemble(self.workers)
        sync_stats(self.workers)

        metrics.counters[STEPS_SAMPLED_COUNTER] = td_metric[
            STEPS_SAMPLED_COUNTER]

        # Modify to CollectMetrics.
        res = self.metric_gen.__call__(None)
        res.update(self.metrics)
        self.step_counter = 0
        print("MB-MPO Iteration Completed")
        return [res]
    else:
        print("MAML Iteration {} Completed".format(self.step_counter))
        self.step_counter += 1

        # Sync workers with meta policy
        print("Syncing Weights with Workers")
        self.workers.sync_weights()
        return []