Example No. 1
def test_optimize_model(policy):
    obs_space, action_space = policy.observation_space, policy.action_space
    train_samples = fake_batch(obs_space, action_space, batch_size=80)
    eval_samples = fake_batch(obs_space, action_space, batch_size=20)

    losses, info = policy.optimize_model(train_samples, eval_samples)

    assert isinstance(losses, list)
    assert all(isinstance(loss, float) for loss in losses)

    assert isinstance(info, dict)
    assert "model_epochs" in info
    assert info["model_epochs"] >= 0
    assert "train_loss(models)" in info
    assert "eval_loss(models)" in info
    assert "grad_norm(models)" in info
Example No. 2
def test_generate_virtual_sample_batch(policy, rollout_schedule):
    obs_space, action_space = policy.observation_space, policy.action_space
    initial_states = 10
    samples = fake_batch(obs_space, action_space, batch_size=initial_states)
    batch = policy.generate_virtual_sample_batch(samples)

    assert isinstance(batch, SampleBatch)
    assert SampleBatch.CUR_OBS in batch
    assert SampleBatch.ACTIONS in batch
    assert SampleBatch.NEXT_OBS in batch
    assert SampleBatch.REWARDS in batch
    assert SampleBatch.DONES in batch

    policy.global_timestep = 10
    for timestep, value in rollout_schedule:
        if policy.global_timestep >= timestep:
            break
    min_length = value

    min_count = min_length * initial_states
    assert batch.count >= min_count
    assert batch[
        SampleBatch.CUR_OBS].shape == (batch.count, ) + obs_space.shape
    assert batch[
        SampleBatch.ACTIONS].shape == (batch.count, ) + action_space.shape
    assert batch[
        SampleBatch.NEXT_OBS].shape == (batch.count, ) + obs_space.shape
    assert batch[SampleBatch.REWARDS].shape == (batch.count, )
    assert batch[SampleBatch.DONES].shape == (batch.count, )
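The rollout_schedule fixture above is not shown; judging from the lookup loop, it is assumed to be a sequence of (timestep, rollout_length) pairs, and the loop keeps the value of the first pair whose timestep has already been reached. A small worked example of that lookup, with a purely hypothetical schedule:

# Hypothetical schedule: (timestep threshold, minimum rollout length) pairs.
rollout_schedule = [(0, 1), (200, 5), (400, 10)]

global_timestep = 10
for timestep, value in rollout_schedule:
    if global_timestep >= timestep:
        break
min_length = value  # first threshold already reached -> min_length == 1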
Example No. 3
    def make_module_and_batch(module_cls, config):
        config["torch_script"] = torch_script
        module = module_cls(obs_space, action_space, config)

        batch = UsageTrackingDict(
            fake_batch(obs_space, action_space, batch_size=10))
        batch.set_get_interceptor(partial(convert_to_tensor, device="cpu"))

        return torch.jit.script(module) if torch_script else module, batch
Example No. 4
def test_improve_policy(trainer_cls, envs, config):
    # pylint:disable=unused-argument
    trainer = trainer_cls(env="MockEnv", config=config)
    env = trainer.workers.local_worker().env

    real_samples = fake_batch(env.observation_space,
                              env.action_space,
                              batch_size=80)
    for row in real_samples.rows():
        trainer.replay.add(row)
    virtual_samples = fake_batch(env.observation_space,
                                 env.action_space,
                                 batch_size=800)
    for row in virtual_samples.rows():
        trainer.virtual_replay.add(row)

    info = trainer.improve_policy(1)
    assert "learner" not in info
    assert "learner_stats" not in info
Example No. 5
def test_madpg_loss(policy_and_env):
    policy, _ = policy_and_env
    batch = policy.lazy_tensor_dict(
        fake_batch(policy.observation_space,
                   policy.action_space,
                   batch_size=10))

    loss, info = policy.loss_actor(batch)
    assert isinstance(info, dict)
    assert loss.shape == ()
    assert loss.dtype == torch.float32
    assert loss.grad_fn is not None

    policy.module.zero_grad()
    loss.backward()
    assert all(p.grad is not None and torch.isfinite(p.grad).all()
               and not torch.isnan(p.grad).any()
               for p in policy.module.actor.parameters())
Example No. 6
def test_generate_virtual_sample_batch(policy):
    obs_space, action_space = policy.observation_space, policy.action_space
    initial_states = 10
    samples = fake_batch(obs_space, action_space, batch_size=initial_states)
    batch = policy.generate_virtual_sample_batch(samples)

    assert isinstance(batch, SampleBatch)
    assert SampleBatch.CUR_OBS in batch
    assert SampleBatch.ACTIONS in batch
    assert SampleBatch.NEXT_OBS in batch
    assert SampleBatch.REWARDS in batch
    assert SampleBatch.DONES in batch

    total_count = policy.model_sampling_spec.rollout_length * initial_states
    assert batch.count == total_count
    assert batch[
        SampleBatch.CUR_OBS].shape == (total_count, ) + obs_space.shape
    assert batch[
        SampleBatch.ACTIONS].shape == (total_count, ) + action_space.shape
    assert batch[
        SampleBatch.NEXT_OBS].shape == (total_count, ) + obs_space.shape
    assert batch[SampleBatch.REWARDS].shape == (total_count, )
    assert batch[SampleBatch.DONES].shape == (total_count, )
Example No. 7
def sample_batch(obs_space, action_space):
    return fake_batch(obs_space, action_space, batch_size=10)
Example No. 8
def env_samples(env_):
    return fake_batch(env_.observation_space, env_.action_space, batch_size=10)
Example No. 9
def make_batch(obs_space, action_space, batch_size=4):
    batch = UsageTrackingDict(
        fake_batch(obs_space, action_space, batch_size=batch_size))
    batch.set_get_interceptor(partial(convert_to_tensor, device="cpu"))
    return batch
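Examples No. 3 and No. 9 wrap the fake batch in RLlib's UsageTrackingDict and register convert_to_tensor as a get-interceptor, so NumPy arrays are converted lazily on access. A short usage sketch, assuming the make_batch helper above and hypothetical Box spaces:

import torch
from gym import spaces
from ray.rllib.policy.sample_batch import SampleBatch

# Hypothetical spaces, for illustration only.
obs_space = spaces.Box(-1.0, 1.0, shape=(3,))
action_space = spaces.Box(-1.0, 1.0, shape=(2,))

batch = make_batch(obs_space, action_space, batch_size=4)
obs = batch[SampleBatch.CUR_OBS]  # the interceptor converts on access
assert torch.is_tensor(obs) and obs.device.type == "cpu"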
Example No. 10
def make_policy_and_batch(policy_cls, config):
    policy = policy_cls(obs_space, action_space, config)
    batch = policy.lazy_tensor_dict(
        fake_batch(obs_space, action_space, batch_size=10))
    return policy, batch
Example No. 11
def batch(obs_space, action_space):
    from raylab.utils.debug import fake_batch

    samples = fake_batch(obs_space, action_space, batch_size=32)
    return {k: torch.from_numpy(v) for k, v in samples.items()}
Example No. 12
def batch(obs_space, action_space):
    samples = fake_batch(obs_space, action_space, batch_size=256)
    return {k: torch.from_numpy(v) for k, v in samples.items()}
Example No. 13
def cont_batch(obs_space, cont_space):
    samples = fake_batch(obs_space, cont_space, batch_size=32)
    return {k: torch.from_numpy(v) for k, v in samples.items()}
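Examples No. 11 to No. 13 return a plain dict of torch tensors rather than a SampleBatch, and such a fixture is typically consumed directly by the component under test. A hypothetical test using the 32-row fixture from Examples No. 11 and No. 13 might look like:

import torch


def test_batch_is_tensor_dict(batch):
    # Hypothetical consumer: every value is a torch tensor with 32 rows,
    # matching the batch_size used in the fixture above.
    assert all(torch.is_tensor(v) for v in batch.values())
    assert all(v.shape[0] == 32 for v in batch.values())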