Example #1
import numpy as np

def test(env, base, batch_size, nested, rollout_length):
    """Roll out a policy and check the sampler's minibatch count."""
    # Policy, NestedVecObWrapper, RolloutDataManager, and RolloutActor
    # come from the surrounding dl package.
    pi = Policy(base(env.observation_space, env.action_space))
    if nested:
        env = NestedVecObWrapper(env)
    nenv = env.num_envs
    data_manager = RolloutDataManager(env,
                                      RolloutActor(pi),
                                      'cpu',
                                      batch_size=batch_size,
                                      rollout_length=rollout_length)
    for _ in range(3):
        data_manager.rollout()
        count = 0
        for batch in data_manager.sampler():
            assert 'key1' in batch
            assert 'done' in batch
            count += 1
            data_manager.act(batch['obs'])
        if data_manager.recurrent:
            # Recurrent policies are batched per environment.
            assert count == np.ceil(nenv / data_manager.batch_size)
        else:
            # Non-recurrent policies are batched per transition.
            n = data_manager.storage.get_rollout()['reward'].data.shape[0]
            assert count == np.ceil(n / data_manager.batch_size)
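A quick illustration of the minibatch-count arithmetic the test checks: non-recurrent sampling iterates over individual transitions, recurrent sampling over environments. The env count, rollout length, and batch size below are illustrative assumptions, not values from the test.

import numpy as np

# Illustrative numbers only: 8 envs, 16 steps per rollout, batches of 32.
nenv, rollout_length, batch_size = 8, 16, 32
n = nenv * rollout_length                # transitions in one rollout
assert np.ceil(n / batch_size) == 4      # per-transition minibatches
assert np.ceil(nenv / batch_size) == 1   # per-env minibatches (recurrent)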
Example #2
def test_policy_base(self):
    """Test Policy base."""
    class Base(PolicyBase):
        def forward(self, ob):
            return DeltaDist(torch.ones([2], dtype=torch.float32))

    env = gym.make('LunarLanderContinuous-v2')
    pi = Policy(Base(env.observation_space, env.action_space))
    pi2 = UnnormActionPolicy(
        Base(env.observation_space, env.action_space))
    ob = env.reset()

    outs = pi(ob[None])
    assert outs.value is None
    assert outs.state_out is None

    outs2 = pi2(ob[None])
    assert outs2.value is None
    assert outs2.state_out is None

    # Map the unnormalized action back to [-1, 1] and compare.
    ac_normed = 2 * (outs2.action - pi2.low) / (pi2.high -
                                                pi2.low) - 1.
    assert torch.allclose(ac_normed, outs.action)
    assert torch.allclose(outs2.normed_action, outs.action)
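The last three assertions rely on the affine map between an unnormalized action in [low, high] and its normalized counterpart in [-1, 1]. A minimal standalone sketch of that round trip, with illustrative bounds standing in for pi2.low and pi2.high:

import torch

low = torch.tensor([-2.0, -1.0])     # illustrative action bounds
high = torch.tensor([2.0, 1.0])
action = torch.tensor([0.5, -0.25])
normed = 2 * (action - low) / (high - low) - 1.0      # [low, high] -> [-1, 1]
unnormed = low + (normed + 1.0) * (high - low) / 2.0  # inverse map
assert torch.allclose(unnormed, action)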
Example #3
def policy_fn(env):
    return Policy(
        FeedForwardActorCriticBase(env.observation_space,
                                   env.action_space))
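Factories like this one are applied to an environment instance to build a policy. A minimal usage sketch, borrowing the env and call pattern from Example #2 (the env choice is an assumption):

import gym

env = gym.make('LunarLanderContinuous-v2')
pi = policy_fn(env)
ob = env.reset()
outs = pi(ob[None])  # add a batch dimension, as in Example #2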
Example #4
def policy_fn(env):
    """Create a policy."""
    return Policy(PiBase(env.observation_space, env.action_space))
Example #5
File: base.py Project: amackeith/dl
def policy_fn(env):
    """Create a policy network."""
    return Policy(
        FeedForwardPolicyBase(env.observation_space, env.action_space))
Example #6
File: base.py Project: amackeith/dl
def continuous_policy_fn(env):
    """Create policy."""
    return Policy(ContinuousPolicyBase(env.observation_space,
                                       env.action_space))
Example #7
File: base.py Project: amackeith/dl
def discrete_policy_fn(env):
    """Create policy."""
    return Policy(DiscretePolicyBase(env.observation_space, env.action_space))
Example #8
def drone_ppo_policy_fn(env, nunits=128):
    """Create a policy network."""
    return Policy(
        FeedForwardActorCriticBase(env.observation_space,
                                   env.action_space,
                                   nunits=nunits))
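Unlike the other factories, this one takes an extra nunits argument. If a one-argument policy_fn is needed (as in the other examples), the width can be bound ahead of time; a sketch using functools.partial:

from functools import partial

# Fix nunits so the result matches the one-argument policy_fn signature.
drone_policy_fn = partial(drone_ppo_policy_fn, nunits=64)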
Example #9
def policy_fn(env):
    return Policy(NatureDQN(env.observation_space,
                            env.action_space))
Example #10
def policy_fn(env):
    """Create policy."""
    return Policy(NatureDQN(env.observation_space, env.action_space))
Example #11
def policy_fn():
    return Policy(TicTacToeNet())
Example #12
def a3c_rnn_fn(env):
    """Create a3c recurrent policy."""
    return Policy(A3CRNN(env.observation_space, env.action_space))
Example #13
def a3c_cnn_fn(env):
    """Create a3c conv net policy."""
    return Policy(A3CCNN(env.observation_space, env.action_space))