Example #1
def _postprocess_dqn(policy_graph, sample_batch):
    obs, actions, rewards, new_obs, dones = [
        list(x) for x in sample_batch.columns(
            ["obs", "actions", "rewards", "new_obs", "dones"])
    ]

    # N-step Q adjustments
    if policy_graph.config["n_step"] > 1:
        _adjust_nstep(policy_graph.config["n_step"],
                      policy_graph.config["gamma"], obs, actions, rewards,
                      new_obs, dones)

    batch = SampleBatch({
        "obs": obs,
        "actions": actions,
        "rewards": rewards,
        "new_obs": new_obs,
        "dones": dones,
        "weights": np.ones_like(rewards)
    })

    # Prioritize on the worker side
    if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
        td_errors = policy_graph.compute_td_error(
            batch["obs"], batch["actions"], batch["rewards"], batch["new_obs"],
            batch["dones"], batch["weights"])
        new_priorities = (
            np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"])
        batch.data["weights"] = new_priorities

    return batch
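Example #1 relies on an in-place n-step helper (_adjust_nstep, called adjust_nstep in Example #3) that is not shown here. A minimal sketch of what such a helper typically does, assuming plain Python lists that are mutated in place:

def adjust_nstep_sketch(n_step, gamma, obs, actions, rewards, new_obs, dones):
    # Hypothetical in-place n-step rewrite: for each step t, fold the
    # discounted rewards of the next n-1 steps into rewards[t] and point
    # new_obs[t]/dones[t] at the transition reached n steps later, stopping
    # early at episode boundaries.
    traj_len = len(rewards)
    for t in range(traj_len):
        if dones[t]:
            continue  # terminal transitions are left untouched
        for k in range(1, n_step):
            if t + k >= traj_len:
                break
            new_obs[t] = new_obs[t + k]
            dones[t] = dones[t + k]
            rewards[t] += gamma**k * rewards[t + k]
            if dones[t + k]:
                break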
Example #2
def _from_json(batch):
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)

    if "type" in data:
        data_type = data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        for k, v in data.items():
            data[k] = unpack_if_needed(v)
        return SampleBatch(data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
            data_type)
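For reference, a JSON record shaped like the one below would take the SampleBatch branch of _from_json; every column name other than "type" is illustrative, and the values are assumed to be in whatever packed or plain form unpack_if_needed accepts.

import json

record = json.dumps({
    "type": "SampleBatch",
    "obs": [[0.1, 0.2], [0.3, 0.4]],
    "actions": [0, 1],
    "rewards": [1.0, 0.0],
    "dones": [False, True],
})
# _from_json(record) would then unpack each column and return SampleBatch(data).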
Example #3
def _postprocess_dqn(policy_graph, sample_batch):
    obs, actions, rewards, new_obs, dones = [
        list(x) for x in sample_batch.columns(
            ["obs", "actions", "rewards", "new_obs", "dones"])
    ]

    # N-step Q adjustments
    if policy_graph.config["n_step"] > 1:
        adjust_nstep(policy_graph.config["n_step"],
                     policy_graph.config["gamma"], obs, actions, rewards,
                     new_obs, dones)

    batch = SampleBatch({
        "obs": obs,
        "actions": actions,
        "rewards": rewards,
        "new_obs": new_obs,
        "dones": dones,
        "weights": np.ones_like(rewards)
    })

    # Prioritize on the worker side
    if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
        td_errors = policy_graph.compute_td_error(
            batch["obs"], batch["actions"], batch["rewards"], batch["new_obs"],
            batch["dones"], batch["weights"])
        new_priorities = (
            np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"])
        batch.data["weights"] = new_priorities

    return batch
Example #4
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.learn_on_batch(samples)
                if "stats" in fetches:
                    self.learner_stats = fetches["stats"]
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches
Example #5
def _from_json(batch):
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)
    for k, v in data.items():
        data[k] = [unpack_if_needed(x) for x in unpack_if_needed(v)]
    return SampleBatch(data)
Example #6
def collect_samples(agents, timesteps_per_batch):
    num_timesteps_so_far = 0
    trajectories = []
    # This variable maps the object IDs of trajectories that are currently
    # computed to the agent that they are computed on; we start some initial
    # tasks here.

    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < timesteps_per_batch:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent

        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)
    return SampleBatch.concat_samples(trajectories)
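collect_samples only assumes that each agent is a Ray actor handle whose remote sample() method returns an object with a .count attribute. A hypothetical caller might look like this (RolloutAgent is an illustration, not part of the example):

import ray

@ray.remote
class RolloutAgent:
    def sample(self):
        # Stand-in for an evaluator: return a tiny SampleBatch per call.
        return SampleBatch({"obs": [0, 1], "rewards": [0.0, 1.0]})

# ray.init()
# agents = [RolloutAgent.remote() for _ in range(4)]
# batch = collect_samples(agents, timesteps_per_batch=64)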
Example #7
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                samples = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                samples = self.local_evaluator.sample()

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.compute_apply(samples)
                if self.num_sgd_iter > 1:
                    print(i, fetches)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches
Example #8
    def _replay(self):
        samples = {}
        with self.replay_timer:
            for policy_id, replay_buffer in self.replay_buffers.items():
                if isinstance(replay_buffer, PrioritizedReplayBuffer):
                    (obses_t, actions, rewards, obses_tp1, dones, weights,
                     batch_indexes) = replay_buffer.sample(
                         self.train_batch_size,
                         beta=self.prioritized_replay_beta.value(
                             self.num_steps_trained))
                else:
                    (obses_t, actions, rewards, obses_tp1,
                     dones) = replay_buffer.sample(self.train_batch_size)
                    weights = np.ones_like(rewards)
                    batch_indexes = -np.ones_like(rewards)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
        return MultiAgentBatch(samples, self.train_batch_size)
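The beta annealing in Example #8 assumes that self.prioritized_replay_beta is a schedule object exposing a value(t) method. A hypothetical linear schedule compatible with that call:

class LinearSchedule:
    # Anneal linearly from initial_p to final_p over schedule_timesteps,
    # then stay at final_p.
    def __init__(self, schedule_timesteps, initial_p, final_p):
        self.schedule_timesteps = schedule_timesteps
        self.initial_p = initial_p
        self.final_p = final_p

    def value(self, t):
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)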
Example #9
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({DEFAULT_POLICY_ID: batch},
                                        batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count
Example #10
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                batch = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                batch = self.local_evaluator.sample()

            # Handle everything as if multiagent
            if isinstance(batch, SampleBatch):
                batch = MultiAgentBatch({
                    DEFAULT_POLICY_ID: batch
                }, batch.count)

            for policy_id, s in batch.policy_batches.items():
                for row in s.rows():
                    self.replay_buffers[policy_id].add(
                        pack_if_needed(row["obs"]),
                        row["actions"],
                        row["rewards"],
                        pack_if_needed(row["new_obs"]),
                        row["dones"],
                        weight=None)

        if self.num_steps_sampled >= self.replay_starts:
            self._optimize()

        self.num_steps_sampled += batch.count
Example #11
    def update_episode_buffer(self, samples):
        if self.config["episode_mode"] == "episode_buffer":
            new_batches = list(separate_sample_batch(samples).values())
            if self._episode_buffer:
                old_batches = \
                    list(separate_sample_batch(self._episode_buffer).values())
            else:
                old_batches = []
            all_batches = old_batches + new_batches
            all_batches = sorted(all_batches,
                                 key=lambda x: x["rewards"].sum(),
                                 reverse=True)
            buffer_size = self.config["buffer_size"]
            self._episode_buffer = SampleBatch.concat_samples(
                all_batches[:buffer_size])
        elif self.config["episode_mode"] == "last_episodes":
            self._episode_buffer = samples
        elif self.config["episode_mode"] == "all_episodes":
            if self._episode_buffer:
                self._episode_buffer = self._episode_buffer.concat(samples)
            else:
                self._episode_buffer = samples
        else:
            raise NotImplementedError
        self._rnn_state_out = None
Example #12
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.compute_apply(samples)
                if "stats" in fetches:
                    self.learner_stats = fetches["stats"]
                if self.num_sgd_iter > 1:
                    logger.debug("{} {}".format(i, fetches))
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
        return fetches
Example #13
def collect_samples_straggler_mitigation(agents, train_batch_size):
    """Collects at least train_batch_size samples.

    This is the legacy behavior as of 0.6, and launches extra sample tasks to
    potentially improve performance but can result in many wasted samples.
    """

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while num_timesteps_so_far < train_batch_size:
        # TODO(pcm): Make wait support arbitrary iterators and remove the
        # conversion to list here.
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        # Start task with next trajectory and record it in the dictionary.
        fut_sample2 = agent.sample.remote()
        agent_dict[fut_sample2] = agent

        next_sample = ray.get(fut_sample)
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

    logger.info("Discarding {} sample tasks".format(len(agent_dict)))
    return SampleBatch.concat_samples(trajectories)
Example #14
def collect_samples(agents, sample_batch_size, num_envs_per_worker,
                    train_batch_size):
    """Collects at least train_batch_size samples, never discarding any."""

    num_timesteps_so_far = 0
    trajectories = []
    agent_dict = {}

    for agent in agents:
        fut_sample = agent.sample.remote()
        agent_dict[fut_sample] = agent

    while agent_dict:
        [fut_sample], _ = ray.wait(list(agent_dict))
        agent = agent_dict.pop(fut_sample)
        next_sample = ray.get(fut_sample)
        assert next_sample.count >= sample_batch_size * num_envs_per_worker
        num_timesteps_so_far += next_sample.count
        trajectories.append(next_sample)

        # Only launch more tasks if we don't already have enough pending
        pending = len(agent_dict) * sample_batch_size * num_envs_per_worker
        if num_timesteps_so_far + pending < train_batch_size:
            fut_sample2 = agent.sample.remote()
            agent_dict[fut_sample2] = agent

    return SampleBatch.concat_samples(trajectories)
Example #15
    def testBatchIds(self):
        ev = PolicyEvaluator(env_creator=lambda _: gym.make("CartPole-v0"),
                             policy_graph=MockPolicyGraph)
        batch1 = ev.sample()
        batch2 = ev.sample()
        self.assertEqual(len(set(batch1["unroll_id"])), 1)
        self.assertEqual(len(set(batch2["unroll_id"])), 1)
        self.assertEqual(
            len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)
Example #16
def postprocess_trajectory(samples, baseline, gamma, lambda_, use_gae):
    separated_samples = separate_sample_batch(samples)
    baseline.fit(separated_samples.values())
    for eps_id, values in separated_samples.items():
        values["vf_preds"] = baseline.predict(values)
        separated_samples[eps_id] = compute_advantages(values, 0.0, gamma,
                                                       lambda_, use_gae)
    samples = SampleBatch.concat_samples(list(separated_samples.values()))
    return samples
Example #17
    def build_and_reset(self):
        """Returns a sample batch including all previously added values."""

        batch = SampleBatch(
            {k: to_float_array(v)
             for k, v in self.buffers.items()})
        self.buffers.clear()
        self.count = 0
        return batch
Example #18
    def build_and_reset(self):
        """Returns a sample batch including all previously added values."""

        batch = SampleBatch(
            {k: to_float_array(v)
             for k, v in self.buffers.items()})
        batch.data[SampleBatch.UNROLL_ID] = np.repeat(self.unroll_id,
                                                      batch.count)
        self.buffers.clear()
        self.count = 0
        self.unroll_id += 1
        return batch
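Examples #17 and #18 assume a to_float_array helper that turns each buffered column into an array. A plausible sketch (the exact behavior is an assumption):

import numpy as np

def to_float_array(v):
    # Stack the per-step values of one column; downcast float64 to float32
    # so the resulting SampleBatch stays compact.
    arr = np.array(v)
    if arr.dtype == np.float64:
        return arr.astype(np.float32)
    return arr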
Example #19
    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.compute_apply(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
Example #20
    def _optimize(self):
        samples = [random.choice(self.replay_buffer)]
        while sum(s.count for s in samples) < self.train_batch_size:
            samples.append(random.choice(self.replay_buffer))
        samples = SampleBatch.concat_samples(samples)
        with self.grad_timer:
            info_dict = self.local_evaluator.compute_apply(samples)
            for policy_id, info in info_dict.items():
                if "stats" in info:
                    self.learner_stats[policy_id] = info["stats"]
            self.grad_timer.push_units_processed(samples.count)
        self.num_steps_trained += samples.count
Example #21
def separate_sample_batch(sample_batch):
    separated_sample_batch = defaultdict(lambda: defaultdict(list))
    for i, eps_id in enumerate(sample_batch["eps_id"]):
        for key in sample_batch.keys():
            separated_sample_batch[eps_id][key].append(sample_batch[key][i])
    for eps_id, values in separated_sample_batch.items():
        for k, v in values.items():
            separated_sample_batch[eps_id][k] = np.stack(v)
        separated_sample_batch[eps_id] = SampleBatch(
            dict(separated_sample_batch[eps_id]))
    separated_sample_batch = dict(separated_sample_batch)
    return separated_sample_batch
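A small illustration of separate_sample_batch, assuming SampleBatch behaves like a dict of equal-length columns: a batch containing two episodes comes back as one SampleBatch per eps_id.

import numpy as np

batch = SampleBatch({
    "eps_id": np.array([7, 7, 9]),
    "rewards": np.array([1.0, 0.0, 2.0]),
})
per_episode = separate_sample_batch(batch)
# per_episode[7]["rewards"] -> array([1., 0.])
# per_episode[9]["rewards"] -> array([2.])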
Example #22
def compute_advantages(rollout, gamma=1, modify=False):
    """Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (SampleBatch): SampleBatch of a single trajectory
        gamma (float): Discount factor.
        modify (bool): If True, clip small non-positive returns to -0.01
            and rescale the returns by 1000.

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards.
    """

    traj = {}

    trajsize = len(rollout["actions"])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    rewards = traj['rewards']

    gammas = np.power(gamma, np.arange(trajsize))
    cum_ret_t = np.zeros(trajsize)
    for t in range(trajsize):
        if t == 0:
            cum_ret_t[t] = np.cumprod(1 + rewards * gammas)[-1]
        else:
            cum_ret_t[t] = np.cumprod(1 + rewards[t:] * gammas[:-t])[-1]

    cum_ret_t -= 1
    if modify:
        cum_ret_t[(-0.01 < cum_ret_t) & (cum_ret_t <= 0)] = -0.01
        cum_ret_t *= 1000

    if 'vf_preds' in traj:
        traj["advantages"] = cum_ret_t - traj['vf_preds']
        traj["value_targets"] = (traj["advantages"] +
                                 traj["vf_preds"]).copy().astype(np.float32)
    else:
        traj["advantages"] = cum_ret_t
        traj["value_targets"] = traj["value_targets"] = np.zeros_like(
            traj["advantages"])

    traj["advantages"] = traj["advantages"].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
Example #23
def compute_returns(rollout, last_r, gamma):
    traj = {}
    trajsize = len(rollout["actions"])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    rewards_plus_v = np.concatenate([rollout["rewards"], np.array([last_r])])
    traj["returns"] = discount(rewards_plus_v, gamma)[:-1]

    traj["returns"] = traj["returns"].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
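The discount helper used in Examples #23, #27 and #30 is not shown; it is a reverse discounted cumulative sum, y[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ... A common way to implement it (assumed here, not taken from these examples):

import scipy.signal

def discount(x, gamma):
    # Discounted cumulative sum via a linear filter over the reversed
    # sequence: y[t] = sum_k gamma**k * x[t + k].
    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]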
Example #24
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            samples = []
            while sum(s.count for s in samples) < self.train_batch_size:
                if self.remote_evaluators:
                    samples.extend(
                        ray.get([
                            e.sample.remote() for e in self.remote_evaluators
                        ]))
                else:
                    samples.append(self.local_evaluator.sample())
            samples = SampleBatch.concat_samples(samples)
            self.sample_timer.push_units_processed(samples.count)

        with self.grad_timer:
            new_samples = self.env_model.process(samples)
            for i in range(self.num_sgd_iter):
                fetches = self.local_evaluator.compute_apply(new_samples)
                if "stats" in fetches:
                    self.learner_stats = fetches["stats"]
                if self.num_sgd_iter > 1:
                    print(i, fetches)
            self.grad_timer.push_units_processed(len(samples["obs"]))

        self.num_steps_sampled += len(samples["obs"])
        self.num_steps_trained += len(samples["obs"])
        return fetches
Example #25
    def _postprocess_if_needed(self, batch):
        if not self.ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID]
                           .postprocess_trajectory(sub_batch))
            return SampleBatch.concat_samples(out)
        else:
            # TODO(ekl) this is trickier since the alignments between agent
            # trajectories in the episode are not available any more.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet.")
Example #26
    def _postprocess_if_needed(self, batch):
        if not self.ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID]
                           .postprocess_trajectory(sub_batch))
            return SampleBatch.concat_samples(out)
        else:
            # TODO(ekl) this is trickier since the alignments between agent
            # trajectories in the episode are not available any more.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet.")
Example #27
def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
    """Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (SampleBatch): SampleBatch of a single trajectory
        last_r (float): Value estimation for last observation
        gamma (float): Discount factor.
        lambda_ (float): Parameter for GAE
        use_gae (bool): Using Generalized Advantage Estimation

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards.
    """

    traj = {}
    trajsize = len(rollout[SampleBatch.ACTIONS])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    if use_gae:
        assert SampleBatch.VF_PREDS in rollout, "Values not found!"
        vpred_t = np.concatenate(
            [rollout[SampleBatch.VF_PREDS],
             np.array([last_r])])
        delta_t = (traj[SampleBatch.REWARDS] + gamma * vpred_t[1:] -
                   vpred_t[:-1])
        # This formula for the advantage comes from
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        traj[Postprocessing.ADVANTAGES] = discount(delta_t, gamma * lambda_)
        traj[Postprocessing.VALUE_TARGETS] = (
            traj[Postprocessing.ADVANTAGES] +
            traj[SampleBatch.VF_PREDS]).copy().astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout[SampleBatch.REWARDS],
             np.array([last_r])])
        traj[Postprocessing.ADVANTAGES] = discount(rewards_plus_v, gamma)[:-1]
        # TODO(ekl): support using a critic without GAE
        traj[Postprocessing.VALUE_TARGETS] = np.zeros_like(
            traj[Postprocessing.ADVANTAGES])

    traj[Postprocessing.ADVANTAGES] = traj[
        Postprocessing.ADVANTAGES].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
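In the notation of Example #27 (V given by vf_preds, with last_r appended as the bootstrap value), the estimator from the cited paper is

\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
\hat{A}_t = \sum_{l \ge 0} (\gamma \lambda)^l \, \delta_{t+l}

which is exactly what discount(delta_t, gamma * lambda_) computes; the value targets are then \hat{A}_t + V(s_t).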
Example #28
    def replay(self):
        with self.replay_timer:
            if len(self.replay_buffer) < self.replay_starts:
                return None

            (obses_t, actions, rewards, obses_tp1, dones, weights,
             batch_indexes) = self.replay_buffer.sample(
                 self.train_batch_size, beta=self.prioritized_replay_beta)

            batch = SampleBatch({
                "obs": obses_t,
                "actions": actions,
                "rewards": rewards,
                "new_obs": obses_tp1,
                "dones": dones,
                "weights": weights,
                "batch_indexes": batch_indexes
            })
            return batch
Example #29
    def replay(self):
        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            samples = {}
            for policy_id, replay_buffer in self.replay_buffers.items():
                (obses_t, actions, rewards, obses_tp1, dones, weights,
                 batch_indexes) = replay_buffer.sample(
                     self.train_batch_size, beta=self.prioritized_replay_beta)
                samples[policy_id] = SampleBatch({
                    "obs": obses_t,
                    "actions": actions,
                    "rewards": rewards,
                    "new_obs": obses_tp1,
                    "dones": dones,
                    "weights": weights,
                    "batch_indexes": batch_indexes
                })
            return MultiAgentBatch(samples, self.train_batch_size)
Example #30
def compute_advantages(rollout, last_r, gamma, lambda_=1.0, use_gae=True):
    """Given a rollout, compute its value targets and the advantage.

    Args:
        rollout (PartialRollout): Partial Rollout Object
        last_r (float): Value estimation for last observation
        gamma (float): Discount factor.
        lambda_ (float): Parameter for GAE
        use_gae (bool): Using Generalized Advantage Estimation

    Returns:
        SampleBatch (SampleBatch): Object with experience from rollout and
            processed rewards.
    """

    traj = {}
    trajsize = len(rollout["actions"])
    for key in rollout:
        traj[key] = np.stack(rollout[key])

    if use_gae:
        assert "vf_preds" in rollout, "Values not found!"
        vpred_t = np.concatenate([rollout["vf_preds"], np.array([last_r])])
        delta_t = traj["rewards"] + gamma * vpred_t[1:] - vpred_t[:-1]
        # This formula for the advantage comes from
        # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
        traj["advantages"] = discount(delta_t, gamma * lambda_)
        traj["value_targets"] = (traj["advantages"] +
                                 traj["vf_preds"]).copy().astype(np.float32)
    else:
        rewards_plus_v = np.concatenate(
            [rollout["rewards"], np.array([last_r])])
        traj["advantages"] = discount(rewards_plus_v, gamma)[:-1]

    traj["advantages"] = traj["advantages"].copy().astype(np.float32)

    assert all(val.shape[0] == trajsize for val in traj.values()), \
        "Rollout stacked incorrectly!"
    return SampleBatch(traj)
Example #31
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                samples = SampleBatch.concat_samples(
                    ray.get(
                        [e.sample.remote() for e in self.remote_evaluators]))
            else:
                samples = self.local_evaluator.sample()

        with self.grad_timer:
            grad, _ = self.local_evaluator.compute_gradients(samples)
            self.local_evaluator.apply_gradients(grad)
            self.grad_timer.push_units_processed(samples.count)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += samples.count
Example #32
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                if self.straggler_mitigation:
                    samples = collect_samples_straggler_mitigation(
                        self.remote_evaluators, self.train_batch_size)
                else:
                    samples = collect_samples(
                        self.remote_evaluators, self.sample_batch_size,
                        self.num_envs_per_worker, self.train_batch_size)
                if samples.count > self.train_batch_size * 2:
                    logger.info(
                        "Collected more training samples than expected "
                        "(actual={}, train_batch_size={}). ".format(
                            samples.count, self.train_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
            else:
                samples = []
                while sum(s.count for s in samples) < self.train_batch_size:
                    samples.append(self.local_evaluator.sample())
                samples = SampleBatch.concat_samples(samples)

            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({
                    DEFAULT_POLICY_ID: samples
                }, samples.count)

        for policy_id, policy in self.policies.items():
            if policy_id not in samples.policy_batches:
                continue

            batch = samples.policy_batches[policy_id]
            for field in self.standardize_fields:
                value = batch[field]
                standardized = (value - value.mean()) / max(1e-4, value.std())
                batch[field] = standardized

            # Important: don't shuffle RNN sequence elements
            if not policy._state_inputs:
                batch.shuffle()

        num_loaded_tuples = {}
        with self.load_timer:
            for policy_id, batch in samples.policy_batches.items():
                if policy_id not in self.policies:
                    continue

                policy = self.policies[policy_id]
                tuples = policy._get_loss_inputs_dict(batch)
                data_keys = [ph for _, ph in policy._loss_inputs]
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        fetches = {}
        with self.grad_timer:
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
                        for k, v in batch_fetches.items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += tuples_per_device * len(self.devices)
        return fetches
Example #33
    def step(self):
        with self.update_weights_timer:
            if self.remote_evaluators:
                weights = ray.put(self.local_evaluator.get_weights())
                for e in self.remote_evaluators:
                    e.set_weights.remote(weights)

        with self.sample_timer:
            if self.remote_evaluators:
                if self.straggler_mitigation:
                    samples = collect_samples_straggler_mitigation(
                        self.remote_evaluators, self.train_batch_size)
                else:
                    samples = collect_samples(
                        self.remote_evaluators, self.sample_batch_size,
                        self.num_envs_per_worker, self.train_batch_size)
                if samples.count > self.train_batch_size * 2:
                    logger.info(
                        "Collected more training samples than expected "
                        "(actual={}, train_batch_size={}). ".format(
                            samples.count, self.train_batch_size) +
                        "This may be because you have many workers or "
                        "long episodes in 'complete_episodes' batch mode.")
            else:
                samples = []
                while sum(s.count for s in samples) < self.train_batch_size:
                    samples.append(self.local_evaluator.sample())
                samples = SampleBatch.concat_samples(samples)

            # Handle everything as if multiagent
            if isinstance(samples, SampleBatch):
                samples = MultiAgentBatch({
                    DEFAULT_POLICY_ID: samples
                }, samples.count)

        for policy_id, policy in self.policies.items():
            if policy_id not in samples.policy_batches:
                continue

            batch = samples.policy_batches[policy_id]
            for field in self.standardize_fields:
                value = batch[field]
                standardized = (value - value.mean()) / max(1e-4, value.std())
                batch[field] = standardized

            # Important: don't shuffle RNN sequence elements
            if not policy._state_inputs:
                batch.shuffle()

        num_loaded_tuples = {}
        with self.load_timer:
            for policy_id, batch in samples.policy_batches.items():
                if policy_id not in self.policies:
                    continue

                policy = self.policies[policy_id]
                tuples = policy._get_loss_inputs_dict(batch)
                data_keys = [ph for _, ph in policy._loss_inputs]
                if policy._state_inputs:
                    state_keys = policy._state_inputs + [policy._seq_lens]
                else:
                    state_keys = []
                num_loaded_tuples[policy_id] = (
                    self.optimizers[policy_id].load_data(
                        self.sess, [tuples[k] for k in data_keys],
                        [tuples[k] for k in state_keys]))

        fetches = {}
        with self.grad_timer:
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]
                num_batches = max(
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
                for i in range(self.num_sgd_iter):
                    iter_extra_fetches = defaultdict(list)
                    permutation = np.random.permutation(num_batches)
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
                        for k, v in batch_fetches[LEARNER_STATS_KEY].items():
                            iter_extra_fetches[k].append(v)
                    logger.debug("{} {}".format(i,
                                                _averaged(iter_extra_fetches)))
                fetches[policy_id] = _averaged(iter_extra_fetches)

        self.num_steps_sampled += samples.count
        self.num_steps_trained += tuples_per_device * len(self.devices)
        self.learner_stats = fetches
        return fetches
Example #34
    def _initialize_loss(self):
        def fake_array(tensor):
            shape = tensor.shape.as_list()
            shape[0] = 1
            return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype)

        dummy_batch = {
            SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input),
            SampleBatch.CUR_OBS: fake_array(self._obs_input),
            SampleBatch.NEXT_OBS: fake_array(self._obs_input),
            SampleBatch.ACTIONS: fake_array(self._prev_action_input),
            SampleBatch.REWARDS: np.array([0], dtype=np.float32),
            SampleBatch.DONES: np.array([False], dtype=np.bool),
        }
        state_init = self.get_initial_state()
        for i, h in enumerate(state_init):
            dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0)
            dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0)
        if state_init:
            dummy_batch["seq_lens"] = np.array([1], dtype=np.int32)
        for k, v in self.extra_compute_action_fetches().items():
            dummy_batch[k] = fake_array(v)

        # postprocessing might depend on variable init, so run it first here
        self._sess.run(tf.global_variables_initializer())
        postprocessed_batch = self.postprocess_trajectory(
            SampleBatch(dummy_batch))

        batch_tensors = UsageTrackingDict({
            SampleBatch.PREV_ACTIONS: self._prev_action_input,
            SampleBatch.PREV_REWARDS: self._prev_reward_input,
            SampleBatch.CUR_OBS: self._obs_input,
        })
        loss_inputs = [
            (SampleBatch.PREV_ACTIONS, self._prev_action_input),
            (SampleBatch.PREV_REWARDS, self._prev_reward_input),
            (SampleBatch.CUR_OBS, self._obs_input),
        ]

        for k, v in postprocessed_batch.items():
            if k in batch_tensors:
                continue
            elif v.dtype == np.object:
                continue  # can't handle arbitrary objects in TF
            shape = (None, ) + v.shape[1:]
            dtype = np.float32 if v.dtype == np.float64 else v.dtype
            placeholder = tf.placeholder(dtype, shape=shape, name=k)
            batch_tensors[k] = placeholder

        if log_once("loss_init"):
            logger.info(
                "Initializing loss function with dummy input:\n\n{}\n".format(
                    summarize(batch_tensors)))

        loss = self._loss_fn(self, batch_tensors)
        if self._stats_fn:
            self._stats_fetches.update(self._stats_fn(self, batch_tensors))
        for k in sorted(batch_tensors.accessed_keys):
            loss_inputs.append((k, batch_tensors[k]))
        TFPolicyGraph._initialize_loss(self, loss, loss_inputs)
        if self._grad_stats_fn:
            self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
        self._sess.run(tf.global_variables_initializer())
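Example #34 depends on a UsageTrackingDict that records which batch columns the loss and stats functions actually read, so that only those columns become loss inputs. A minimal sketch consistent with that usage (an assumption, not the real helper):

class UsageTrackingDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accessed_keys = set()

    def __getitem__(self, key):
        # Remember every key that gets read.
        self.accessed_keys.add(key)
        return super().__getitem__(key)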