Example #1
    def test_traj_view_lstm_functionality(self):
        action_space = Box(-float("inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 200
        assert rollout_fragment_length % max_seq_len == 0
        policies = {
            "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "use_lstm": True,
                "max_seq_len": max_seq_len,
            },
        }

        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=dict(config, **{"_use_trajectory_view_api": True}),
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        rollout_worker_wo_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=dict(config, **{"_use_trajectory_view_api": False}),
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        for iteration in range(20):
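            # Sample via the trajectory-view API, sanity-check the padded RNN
            # batch, then verify the classic path yields identical data.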
            result = rollout_worker_w_api.sample()
            check(result.count, rollout_fragment_length)
            pol_batch_w = result.policy_batches["pol0"]
            assert pol_batch_w.count >= rollout_fragment_length
            analyze_rnn_batch(pol_batch_w, max_seq_len)

            result = rollout_worker_wo_api.sample()
            pol_batch_wo = result.policy_batches["pol0"]
            check(pol_batch_w.data, pol_batch_wo.data)
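
These snippets are methods of an RLlib test class, so names such as Box, RolloutWorker, MultiAgentDebugCounterEnv, the EpisodeEnvAware* policies, and check() come from elsewhere in that module. A minimal import sketch, assuming the module layout of the Ray releases these tests target (paths may differ in other versions; analyze_rnn_batch is a helper defined in the test file itself):

from gym.spaces import Box  # newer Ray releases use gymnasium.spaces

from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.examples.policy.episode_env_aware_policy import (
    EpisodeEnvAwareAttentionPolicy,
    EpisodeEnvAwareLSTMPolicy,
)
from ray.rllib.utils.test_utils import check
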
Example #2
    def test_traj_view_attention_functionality(self):
        action_space = Box(-float("inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 201
        policies = {
            "pol0":
            (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "max_seq_len": max_seq_len,
            },
        }

        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=dict(config, **{"_use_trajectory_view_api": True}),
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        batch = rollout_worker_w_api.sample()
        print(batch)
Example #3
    def test_traj_view_attention_functionality(self):
        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 201
        policies = {
            "pol0":
            (EpisodeEnvAwareAttentionPolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id, episode, **kwargs):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "max_seq_len": max_seq_len,
            },
        }

        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=config,
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            normalize_actions=False,
            num_envs=1,
        )
        batch = rollout_worker_w_api.sample()  # noqa: F841
Example #4
    def test_traj_view_lstm_functionality(self):
        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 200
        assert rollout_fragment_length % max_seq_len == 0
        policies = {
            "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id, episode, **kwargs):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "use_lstm": True,
                "max_seq_len": max_seq_len,
            },
        }

        rw = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=config,
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            normalize_actions=False,
            num_envs=1,
        )

        for iteration in range(20):
            result = rw.sample()
            check(result.count, rollout_fragment_length)
            pol_batch_w = result.policy_batches["pol0"]
            assert pol_batch_w.count >= rollout_fragment_length
            analyze_rnn_batch(
                pol_batch_w,
                max_seq_len,
                view_requirements=rw.policy_map["pol0"].view_requirements,
            )
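
Example #4 hands the policy's view_requirements dict to the batch check. A quick, hypothetical way to inspect what the LSTM policy asks the sampler to collect, using the rw worker built above (attribute names follow RLlib's ViewRequirement class):

for col, vr in rw.policy_map["pol0"].view_requirements.items():
    # Each entry maps a column name ("obs", "prev_rewards", "state_in_0", ...)
    # to a ViewRequirement describing its time shift and space.
    print(f"{col}: shift={vr.shift}, space={vr.space}")
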
Example #5
    def test_traj_view_lstm_functionality(self):
        action_space = Box(-float("inf"), float("inf"), shape=(2, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        policies = {
            "pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id):
            return "pol0"

        rollout_worker = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config={
                "multiagent": {
                    "policies": policies,
                    "policy_mapping_fn": policy_fn,
                },
                "_use_trajectory_view_api": True,
                "model": {
                    "use_lstm": True,
                    "_time_major": True,
                    "max_seq_len": max_seq_len,
                },
            },
            policy=policies,
            policy_mapping_fn=policy_fn,
            num_envs=1,
        )
        for i in range(100):
            pc = rollout_worker.sampler.sample_collector. \
                policy_sample_collectors["pol0"]
            sample_batch_offset_before = pc.sample_batch_offset
            buffers = pc.buffers
            result = rollout_worker.sample()
            pol_batch = result.policy_batches["pol0"]

            self.assertTrue(result.count == 100)
            self.assertTrue(pol_batch.count >= 100)
            self.assertFalse(0 in pol_batch.seq_lens)
            # Check prev_reward/action, next_obs consistency.
            for t in range(max_seq_len):
                obs_t = pol_batch["obs"][t]
                r_t = pol_batch["rewards"][t]
                if t > 0:
                    next_obs_t_m_1 = pol_batch["new_obs"][t - 1]
                    self.assertTrue((obs_t == next_obs_t_m_1).all())
                if t < max_seq_len - 1:
                    prev_rewards_t_p_1 = pol_batch["prev_rewards"][t + 1]
                    self.assertTrue((r_t == prev_rewards_t_p_1).all())

            # Check the sanity of all the buffers in the underlying
            # PerPolicy collector.
            for sample_batch_slot, agent_slot in enumerate(
                    range(sample_batch_offset_before, pc.sample_batch_offset)):
                t_buf = buffers["t"][:, agent_slot]
                obs_buf = buffers["obs"][:, agent_slot]
                # Skip empty seqs at the end (these won't be part of the batch
                # and have been copied to new agent-slots, even if seq-len=0).
                if sample_batch_slot < len(pol_batch.seq_lens):
                    seq_len = pol_batch.seq_lens[sample_batch_slot]
                    # Make sure timesteps are always increasing within the seq.
                    assert all(t_buf[1] + j == n + 1
                               for j, n in enumerate(t_buf)
                               if j < seq_len and j != 0)
                    # Make sure all obs within seq are non-0.0.
                    assert all(
                        any(obs_buf[j] != 0.0) for j in range(1, seq_len + 1))

            # Check seq-lens.
            for agent_slot, seq_len in enumerate(pol_batch.seq_lens):
                if seq_len < max_seq_len - 1:
                    # At least in the beginning, the next slots should always
                    # be empty. Once all agent slots have been used once, these
                    # may be filled with "old" values from longer sequences.
                    if i < 10:
                        self.assertTrue(
                            (pol_batch["obs"][seq_len +
                                              1][agent_slot] == 0.0).all())
                    self.assertFalse(
                        (pol_batch["obs"][seq_len][agent_slot] == 0.0).all())
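
All of these examples are test-class methods (they take self and use self.assertTrue / self.assertFalse) and expect a Ray runtime to be running before sampling. A minimal harness sketch, assuming a class name of TestTrajectoryViewAPI (the init/shutdown pattern mirrors RLlib's own test files):

import unittest

import ray


class TestTrajectoryViewAPI(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        # Start a local Ray instance once for the whole test class.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    # ... paste the test_traj_view_* methods shown above here ...


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))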