def _postprocess_dqn(policy_graph, sample_batch):
    obs, actions, rewards, new_obs, dones = [
        list(x) for x in sample_batch.columns(
            ["obs", "actions", "rewards", "new_obs", "dones"])
    ]

    # N-step Q adjustments
    if policy_graph.config["n_step"] > 1:
        _adjust_nstep(policy_graph.config["n_step"],
                      policy_graph.config["gamma"], obs, actions, rewards,
                      new_obs, dones)

    batch = SampleBatch({
        "obs": obs,
        "actions": actions,
        "rewards": rewards,
        "new_obs": new_obs,
        "dones": dones,
        "weights": np.ones_like(rewards)
    })

    # Prioritize on the worker side
    if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
        td_errors = policy_graph.compute_td_error(
            batch["obs"], batch["actions"], batch["rewards"],
            batch["new_obs"], batch["dones"], batch["weights"])
        new_priorities = (
            np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"])
        batch.data["weights"] = new_priorities

    return batch
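# A minimal sketch of the n-step helper called above, assuming the in-place
# convention implied by the call site (the lists are mutated, nothing is
# returned): each transition absorbs the discounted rewards of the following
# n-1 steps and inherits their final new_obs and done flag. Treat the exact
# signature and body as an assumption, not the canonical implementation.
def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
    assert not any(dones[:-1]), "Unexpected done in middle of trajectory"
    traj_length = len(rewards)
    for i in range(traj_length):
        for j in range(1, n_step):
            if i + j < traj_length:
                new_obs[i] = new_obs[i + j]
                dones[i] = dones[i + j]
                rewards[i] += gamma**j * rewards[i + j]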
def _from_json(batch):
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)

    if "type" in data:
        data_type = data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        for k, v in data.items():
            data[k] = unpack_if_needed(v)
        return SampleBatch(data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
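# A small usage sketch for _from_json, assuming column values were serialized
# as plain JSON lists (so unpack_if_needed passes them through unchanged).
# The record layout below is illustrative, not a canonical on-disk format.
record = json.dumps({
    "type": "SampleBatch",
    "obs": [[0.1, 0.2], [0.3, 0.4]],
    "actions": [0, 1],
    "rewards": [1.0, 0.5],
})
batch = _from_json(record)
assert batch.count == 2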
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray.get([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.learn_on_batch(samples)
            if "stats" in fetches:
                self.learner_stats = fetches["stats"]
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
def _from_json(batch):
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)
    for k, v in data.items():
        data[k] = [unpack_if_needed(x) for x in unpack_if_needed(v)]
    return SampleBatch(data)
def collect_samples(agents, timesteps_per_batch):
    num_timesteps_so_far = 0
    trajectories = []
    # This variable maps the object IDs of trajectories that are currently
    # computed to the agent that they are computed on; we start some initial
    # tasks here.
    agent_dict = {}
    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent
    while num_timesteps_so_far < timesteps_per_batch:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent
        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)
    return SampleBatch.concat_samples(trajectories)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            samples = SampleBatch.concat_samples(
                ray.get(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            samples = self.local_evaluator.sample()

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.compute_apply(samples)
            if self.num_sgd_iter > 1:
                print(i, fetches)
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
def _replay(self):
    samples = {}
    with self.replay_timer:
        for policy_id, replay_buffer in self.replay_buffers.items():
            if isinstance(replay_buffer, PrioritizedReplayBuffer):
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = replay_buffer.sample(
                     self.train_batch_size,
                     beta=self.prioritized_replay_beta.value(
                         self.num_steps_trained))
            else:
                (obses_t, actions, rewards, obses_tp1,
                 dones) = replay_buffer.sample(self.train_batch_size)
                weights = np.ones_like(rewards)
                batch_indexes = -np.ones_like(rewards)
            samples[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
    return MultiAgentBatch(samples, self.train_batch_size)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            batch = SampleBatch.concat_samples(
                ray.get(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            batch = self.local_evaluator.sample()

    # Handle everything as if multiagent
    if isinstance(batch, SampleBatch):
        batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch}, batch.count)

    for policy_id, s in batch.policy_batches.items():
        for row in s.rows():
            self.replay_buffers[policy_id].add(
                pack_if_needed(row["obs"]),
                row["actions"],
                row["rewards"],
                pack_if_needed(row["new_obs"]),
                row["dones"],
                weight=None)

    if self.num_steps_sampled >= self.replay_starts:
        self._optimize()

    self.num_steps_sampled += batch.count
def update_episode_buffer(self, samples):
    if self.config["episode_mode"] == "episode_buffer":
        new_batches = list(separate_sample_batch(samples).values())
        if self._episode_buffer:
            old_batches = \
                list(separate_sample_batch(self._episode_buffer).values())
        else:
            old_batches = []
        all_batches = old_batches + new_batches
        all_batches = sorted(
            all_batches, key=lambda x: x["rewards"].sum(), reverse=True)
        buffer_size = self.config["buffer_size"]
        self._episode_buffer = SampleBatch.concat_samples(
            all_batches[:buffer_size])
    elif self.config["episode_mode"] == "last_episodes":
        self._episode_buffer = samples
    elif self.config["episode_mode"] == "all_episodes":
        if self._episode_buffer:
            self._episode_buffer = self._episode_buffer.concat(samples)
        else:
            self._episode_buffer = samples
    else:
        raise NotImplementedError
    self._rnn_state_out = None
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray.get([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.compute_apply(samples)
            if "stats" in fetches:
                self.learner_stats = fetches["stats"]
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
    return fetches
def collect_samples_straggler_mitigation(agents, train_batch_size):
    """Collects at least train_batch_size samples.

    This is the legacy behavior as of 0.6, and launches extra sample tasks
    to potentially improve performance but can result in many wasted samples.
    """
    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < train_batch_size:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent
        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    logger.info("Discarding {} sample tasks".format(len(agent_dict)))
    return SampleBatch.concat_samples(trajectories)
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""
    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray.get(fut_sample)
        assert next_sample.count >= sample_batch_size * num_envs_per_worker
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending
        pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
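# A hedged usage sketch for collect_samples, using a toy Ray actor whose
# sample() returns a fixed-size SampleBatch. The ToyWorker actor here is
# illustrative; in RLlib the agents would be remote evaluator handles.
@ray.remote
class ToyWorker(object):
    def sample(self):
        # Five transitions per call, matching sample_batch_size=5 below.
        return SampleBatch({
            "obs": np.zeros((5, 4)),
            "rewards": np.ones(5),
        })

# Requires ray.init() first:
# workers = [ToyWorker.remote() for _ in range(4)]
# batch = collect_samples(workers, sample_batch_size=5,
#                         num_envs_per_worker=1, train_batch_size=20)
# assert batch.count >= 20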
def testBatchIds(self):
    ev = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph)
    batch1 = ev.sample()
    batch2 = ev.sample()
    self.assertEqual(len(set(batch1["unroll_id"])), 1)
    self.assertEqual(len(set(batch2["unroll_id"])), 1)
    self.assertEqual(
        len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
def postprocess_trajectory(samples, baseline, gamma, lambda_, use_gae):
    separated_samples = separate_sample_batch(samples)
    baseline.fit(separated_samples.values())
    for eps_id, values in separated_samples.items():
        values["vf_preds"] = baseline.predict(values)
        separated_samples[eps_id] = compute_advantages(
            values, 0.0, gamma, lambda_, use_gae)
    samples = SampleBatch.concat_samples(list(separated_samples.values()))
    return samples
def build_and_reset(self):
    """Returns a sample batch including all previously added values."""
    batch = SampleBatch(
        {k: to_float_array(v) for k, v in self.buffers.items()})
    self.buffers.clear()
    self.count = 0
    return batch
def build_and_reset(self):
    """Returns a sample batch including all previously added values."""
    batch = SampleBatch(
        {k: to_float_array(v) for k, v in self.buffers.items()})
    batch.data[SampleBatch.UNROLL_ID] = np.repeat(self.unroll_id,
                                                  batch.count)
    self.buffers.clear()
    self.count = 0
    self.unroll_id += 1
    return batch
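# A minimal sketch of the builder these build_and_reset methods belong to,
# assuming a simple add_values(**row) interface that appends one row per
# call. The class and method names follow RLlib's SampleBatchBuilder, but
# the body here is illustrative, not the canonical implementation.
import collections

class SampleBatchBuilder(object):
    def __init__(self):
        self.buffers = collections.defaultdict(list)
        self.count = 0
        self.unroll_id = 0

    def add_values(self, **row):
        # Append one value per column; all columns stay row-aligned.
        for k, v in row.items():
            self.buffers[k].append(v)
        self.count += 1

# With one of the build_and_reset methods above attached:
# builder = SampleBatchBuilder()
# builder.add_values(obs=[0.0, 1.0], actions=0, rewards=1.0)
# batch = builder.build_and_reset()  # -> SampleBatch with count == 1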
def _optimize(self):
    samples = [random.choice(self.replay_buffer)]
    while sum(s.count for s in samples) < self.train_batch_size:
        samples.append(random.choice(self.replay_buffer))
    samples = SampleBatch.concat_samples(samples)

    with self.grad_timer:
        info_dict = self.local_evaluator.compute_apply(samples)
        for policy_id, info in info_dict.items():
            if "stats" in info:
                self.learner_stats[policy_id] = info["stats"]
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_trained += samples.count
def separate_sample_batch(sample_batch):
    separated_sample_batch = defaultdict(lambda: defaultdict(list))
    for i, eps_id in enumerate(sample_batch["eps_id"]):
        for key in sample_batch.keys():
            separated_sample_batch[eps_id][key].append(sample_batch[key][i])
    for eps_id, values in separated_sample_batch.items():
        for k, v in values.items():
            separated_sample_batch[eps_id][k] = np.stack(v)
        separated_sample_batch[eps_id] = SampleBatch(
            dict(separated_sample_batch[eps_id]))
    separated_sample_batch = dict(separated_sample_batch)
    return separated_sample_batch
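# A small usage sketch for separate_sample_batch: rows sharing an eps_id come
# back grouped into one SampleBatch per episode. Values are illustrative.
mixed = SampleBatch({
    "eps_id": np.array([7, 7, 9]),
    "rewards": np.array([1.0, 2.0, 5.0]),
})
per_episode = separate_sample_batch(mixed)
assert sorted(per_episode.keys()) == [7, 9]
assert per_episode[7]["rewards"].sum() == 3.0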
def compute_advantages(rollout, gamma=1, modify=False):
    """Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (SampleBatch): SampleBatch of a single trajectory
        gamma (float): Discount factor.
        modify (bool): Whether to clip and rescale the cumulative returns.

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards.
    """
    traj = {}
    trajsize = len(rollout["actions"])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    rewards = traj["rewards"]
    gammas = np.power(gamma, np.arange(trajsize))
    cum_ret_t = np.zeros(trajsize)
    for t in range(trajsize):
        if t == 0:
            cum_ret_t[t] = np.cumprod(1 + rewards * gammas)[-1]
        else:
            cum_ret_t[t] = np.cumprod(1 + rewards[t:] * gammas[:-t])[-1]
    cum_ret_t -= 1
    if modify:
        cum_ret_t[(-0.01 < cum_ret_t) & (cum_ret_t <= 0)] = -0.01
        cum_ret_t *= 1000

    if "vf_preds" in traj:
        traj["advantages"] = cum_ret_t - traj["vf_preds"]
        traj["value_targets"] = (
            traj["advantages"] + traj["vf_preds"]).copy().astype(np.float32)
    else:
        traj["advantages"] = cum_ret_t
        traj["value_targets"] = np.zeros_like(traj["advantages"])
    traj["advantages"] = traj["advantages"].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
def compute_returns(rollout, last_r, gamma):
    traj = {}
    trajsize = len(rollout["actions"])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    rewards_plus_v = np.concatenate([rollout["rewards"], np.array([last_r])])
    traj["returns"] = discount(rewards_plus_v, gamma)[:-1]
    traj["returns"] = traj["returns"].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
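# The discount helper assumed by compute_returns and compute_advantages is
# not shown in this collection; RLlib's version is a one-liner over
# scipy.signal.lfilter, reproduced here as a sketch. It satisfies
#   discount(x, gamma)[t] == sum over k of gamma**k * x[t + k].
import scipy.signal

def discount(x, gamma):
    # Run a first-order IIR filter over the reversed array, then reverse
    # back, so each entry accumulates the discounted tail of the sequence.
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]

# e.g. discount(np.array([1.0, 1.0, 1.0]), 0.5) -> [1.75, 1.5, 1.0]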
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        samples = []
        while sum(s.count for s in samples) < self.train_batch_size:
            if self.remote_evaluators:
                samples.extend(
                    ray.get([
                        e.sample.remote() for e in self.remote_evaluators
                    ]))
            else:
                samples.append(self.local_evaluator.sample())
        samples = SampleBatch.concat_samples(samples)
        self.sample_timer.push_units_processed(samples.count)

    with self.grad_timer:
        # Pass the collected samples through the environment model before
        # applying SGD on the processed batch.
        new_samples = self.env_model.process(samples)
        for i in range(self.num_sgd_iter):
            fetches = self.local_evaluator.compute_apply(new_samples)
            if "stats" in fetches:
                self.learner_stats = fetches["stats"]
            if self.num_sgd_iter > 1:
                logger.debug("{} {}".format(i, fetches))
        # Count progress in terms of the raw samples collected, not the
        # model-processed batch.
        self.grad_timer.push_units_processed(len(samples["obs"]))

    self.num_steps_sampled += len(samples["obs"])
    self.num_steps_trained += len(samples["obs"])
    return fetches
def _postprocess_if_needed(self, batch):
    if not self.ioctx.config.get("postprocess_inputs"):
        return batch

    if isinstance(batch, SampleBatch):
        out = []
        for sub_batch in batch.split_by_episode():
            out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID]
                       .postprocess_trajectory(sub_batch))
        return SampleBatch.concat_samples(out)
    else:
        # TODO(ekl) this is trickier since the alignments between agent
        # trajectories in the episode are not available any more.
        raise NotImplementedError(
            "Postprocessing of multi-agent data not implemented yet.")
def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
    """Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (SampleBatch): SampleBatch of a single trajectory
        last_r (float): Value estimation for last observation
        gamma (float): Discount factor.
        lambda_ (float): Parameter for GAE
        use_gae (bool): Using Generalized Advantage Estimation

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards.
    """
    traj = {}
    trajsize = len(rollout[SampleBatch.ACTIONS])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    if use_gae:
        assert SampleBatch.VF_PREDS in rollout, "Values not found!"
        vpred_t = np.concatenate(
            [rollout[SampleBatch.VF_PREDS],
             np.array([last_r])])
        delta_t = (
            traj[SampleBatch.REWARDS] + gamma * vpred_t[1:] - vpred_t[:-1])
        # This formula for the advantage comes from "Generalized Advantage
        # Estimation": https://arxiv.org/abs/1506.02438
        traj[Postprocessing.ADVANTAGES] = discount(delta_t, gamma * lambda_)
        traj[Postprocessing.VALUE_TARGETS] = (
            traj[Postprocessing.ADVANTAGES] +
            traj[SampleBatch.VF_PREDS]).copy().astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout[SampleBatch.REWARDS],
             np.array([last_r])])
        traj[Postprocessing.ADVANTAGES] = discount(rewards_plus_v, gamma)[:-1]
        # TODO(ekl): support using a critic without GAE
        traj[Postprocessing.VALUE_TARGETS] = np.zeros_like(
            traj[Postprocessing.ADVANTAGES])

    traj[Postprocessing.ADVANTAGES] = traj[
        Postprocessing.ADVANTAGES].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
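# A hedged worked example for the GAE path above, assuming the scipy-based
# discount helper sketched earlier. With lambda_=1.0 the advantage is just
# the discounted sum of TD residuals to the end of the trajectory.
rollout = SampleBatch({
    SampleBatch.ACTIONS: np.array([0, 1]),
    SampleBatch.REWARDS: np.array([1.0, 1.0], dtype=np.float32),
    SampleBatch.VF_PREDS: np.array([0.5, 0.5], dtype=np.float32),
})
out = compute_advantages(rollout, last_r=0.0, gamma=0.99, lambda_=1.0)
# delta_t    = [1.0 + 0.99 * 0.5 - 0.5, 1.0 + 0.0 - 0.5] = [0.995, 0.5]
# advantages = [0.995 + 0.99 * 0.5, 0.5]                 = [1.49, 0.5]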
def replay(self):
    with self.replay_timer:
        if len(self.replay_buffer) < self.replay_starts:
            return None

        (obses_t, actions, rewards, obses_tp1, dones, weights,
         batch_indexes) = self.replay_buffer.sample(
             self.train_batch_size, beta=self.prioritized_replay_beta)

        batch = SampleBatch({
            "obs": obses_t,
            "actions": actions,
            "rewards": rewards,
            "new_obs": obses_tp1,
            "dones": dones,
            "weights": weights,
            "batch_indexes": batch_indexes
        })

    return batch
def replay(self):
    if self.num_added < self.replay_starts:
        return None

    with self.replay_timer:
        samples = {}
        for policy_id, replay_buffer in self.replay_buffers.items():
            (obses_t, actions, rewards, obses_tp1, dones, weights,
             batch_indexes) = replay_buffer.sample(
                 self.train_batch_size, beta=self.prioritized_replay_beta)
            samples[policy_id] = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
    return MultiAgentBatch(samples, self.train_batch_size)
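# A hedged sketch of how the "batch_indexes" column emitted by replay() is
# typically consumed: after computing TD errors on the sampled batch, buffer
# priorities are updated at those indexes. The policy and buffer methods
# follow the DQN code earlier in this collection, but the wiring here is
# illustrative, not the canonical optimizer step.
def update_priorities(replay_buffers, samples, policies,
                      prioritized_replay_eps):
    for policy_id, batch in samples.policy_batches.items():
        td_errors = policies[policy_id].compute_td_error(
            batch["obs"], batch["actions"], batch["rewards"],
            batch["new_obs"], batch["dones"], batch["weights"])
        new_priorities = np.abs(td_errors) + prioritized_replay_eps
        replay_buffers[policy_id].update_priorities(
            batch["batch_indexes"], new_priorities)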
def compute_advantages(rollout, last_r, gamma, lambda_=1.0, use_gae=True):
    """Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (PartialRollout): Partial rollout object
        last_r (float): Value estimation for last observation
        gamma (float): Discount factor.
        lambda_ (float): Parameter for GAE
        use_gae (bool): Using Generalized Advantage Estimation

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards.
    """
    traj = {}
    trajsize = len(rollout["actions"])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    if use_gae:
        assert "vf_preds" in rollout, "Values not found!"
        vpred_t = np.concatenate([rollout["vf_preds"], np.array([last_r])])
        delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
        # This formula for the advantage comes from "Generalized Advantage
        # Estimation": https://arxiv.org/abs/1506.02438
        traj["advantages"] = discount(delta_t, gamma * lambda_)
        traj["value_targets"] = (
            traj["advantages"] + traj["vf_preds"]).copy().astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout["rewards"], np.array([last_r])])
        traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]

    traj["advantages"] = traj["advantages"].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            samples = SampleBatch.concat_samples(
                ray.get(
                    [e.sample.remote() for e in self.remote_evaluators]))
        else:
            samples = self.local_evaluator.sample()

    with self.grad_timer:
        grad, _ = self.local_evaluator.compute_gradients(samples)
        self.local_evaluator.apply_gradients(grad)
        self.grad_timer.push_units_processed(samples.count)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += samples.count
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            if self.straggler_mitigation:
                samples = collect_samples_straggler_mitigation(
                    self.remote_evaluators, self.train_batch_size)
            else:
                samples = collect_samples(
                    self.remote_evaluators, self.sample_batch_size,
                    self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({
            DEFAULT_POLICY_ID: samples
        }, samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not policy._state_inputs:
            batch.shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)
                    for k, v in batch_fetches.items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i,
                                            _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += tuples_per_device * len(self.devices)
    return fetches
def step(self):
    with self.update_weights_timer:
        if self.remote_evaluators:
            weights = ray.put(self.local_evaluator.get_weights())
            for e in self.remote_evaluators:
                e.set_weights.remote(weights)

    with self.sample_timer:
        if self.remote_evaluators:
            if self.straggler_mitigation:
                samples = collect_samples_straggler_mitigation(
                    self.remote_evaluators, self.train_batch_size)
            else:
                samples = collect_samples(
                    self.remote_evaluators, self.sample_batch_size,
                    self.num_envs_per_worker, self.train_batch_size)
            if samples.count > self.train_batch_size * 2:
                logger.info(
                    "Collected more training samples than expected "
                    "(actual={}, train_batch_size={}). ".format(
                        samples.count, self.train_batch_size) +
                    "This may be because you have many workers or "
                    "long episodes in 'complete_episodes' batch mode.")
        else:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)

    # Handle everything as if multiagent
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({
            DEFAULT_POLICY_ID: samples
        }, samples.count)

    for policy_id, policy in self.policies.items():
        if policy_id not in samples.policy_batches:
            continue

        batch = samples.policy_batches[policy_id]
        for field in self.standardize_fields:
            value = batch[field]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch[field] = standardized

        # Important: don't shuffle RNN sequence elements
        if not policy._state_inputs:
            batch.shuffle()

    num_loaded_tuples = {}
    with self.load_timer:
        for policy_id, batch in samples.policy_batches.items():
            if policy_id not in self.policies:
                continue

            policy = self.policies[policy_id]
            tuples = policy._get_loss_inputs_dict(batch)
            data_keys = [ph for _, ph in policy._loss_inputs]
            if policy._state_inputs:
                state_keys = policy._state_inputs + [policy._seq_lens]
            else:
                state_keys = []
            num_loaded_tuples[policy_id] = (
                self.optimizers[policy_id].load_data(
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

    fetches = {}
    with self.grad_timer:
        for policy_id, tuples_per_device in num_loaded_tuples.items():
            optimizer = self.optimizers[policy_id]
            num_batches = max(
                1,
                int(tuples_per_device) // int(self.per_device_batch_size))
            logger.debug("== sgd epochs for {} ==".format(policy_id))
            for i in range(self.num_sgd_iter):
                iter_extra_fetches = defaultdict(list)
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    batch_fetches = optimizer.optimize(
                        self.sess, permutation[batch_index] *
                        self.per_device_batch_size)
                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                        iter_extra_fetches[k].append(v)
                logger.debug("{} {}".format(i,
                                            _averaged(iter_extra_fetches)))
            fetches[policy_id] = _averaged(iter_extra_fetches)

    self.num_steps_sampled += samples.count
    self.num_steps_trained += tuples_per_device * len(self.devices)
    self.learner_stats = fetches
    return fetches
def _initialize_loss(self):
    def fake_array(tensor):
        shape = tensor.shape.as_list()
        shape[0] = 1
        return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

    dummy_batch = {
        SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
        SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
        SampleBatch.CUR_OBS: fake_array(self._obs_input),
        SampleBatch.NEXT_OBS: fake_array(self._obs_input),
        SampleBatch.ACTIONS: fake_array(self._prev_action_input),
        SampleBatch.REWARDS: np.array([0], dtype=np.float32),
        SampleBatch.DONES: np.array([False], dtype=np.bool),
    }
    state_init = self.get_initial_state()
    for i, h in enumerate(state_init):
        dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
        dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
    if state_init:
        dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
    for k, v in self.extra_compute_action_fetches().items():
        dummy_batch[k] = fake_array(v)

    # postprocessing might depend on variable init, so run it first here
    self._sess.run(tf.global_variables_initializer())
    postprocessed_batch = self.postprocess_trajectory(
        SampleBatch(dummy_batch))

    batch_tensors = UsageTrackingDict({
        SampleBatch.PREV_ACTIONS: self._prev_action_input,
        SampleBatch.PREV_REWARDS: self._prev_reward_input,
        SampleBatch.CUR_OBS: self._obs_input,
    })
    loss_inputs = [
        (SampleBatch.PREV_ACTIONS, self._prev_action_input),
        (SampleBatch.PREV_REWARDS, self._prev_reward_input),
        (SampleBatch.CUR_OBS, self._obs_input),
    ]

    for k, v in postprocessed_batch.items():
        if k in batch_tensors:
            continue
        elif v.dtype == np.object:
            continue  # can't handle arbitrary objects in TF
        shape = (None, ) + v.shape[1:]
        dtype = np.float32 if v.dtype == np.float64 else v.dtype
        placeholder = tf.placeholder(dtype, shape=shape, name=k)
        batch_tensors[k] = placeholder

    if log_once("loss_init"):
        logger.info(
            "Initializing loss function with dummy input:\n\n{}\n".format(
                summarize(batch_tensors)))

    loss = self._loss_fn(self, batch_tensors)
    if self._stats_fn:
        self._stats_fetches.update(self._stats_fn(self, batch_tensors))
    for k in sorted(batch_tensors.accessed_keys):
        loss_inputs.append((k, batch_tensors[k]))

    TFPolicyGraph._initialize_loss(self, loss, loss_inputs)
    if self._grad_stats_fn:
        self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
    self._sess.run(tf.global_variables_initializer())