Пример #1
0
 def test_get_phi(self):
     """Check the sizes of the two tensors returned by ``get_phi``.

     ``get_phi(batch_size, num_samples)`` is expected to yield ``phi`` of
     size ``(batch_size, num_samples, inner_size)`` and ``tau`` of size
     ``(batch_size, num_samples)``.
     """
     inner_size = 256
     batch_size, num_samples = 3, 8
     default_samples = 32

     # The perception net is built before the model, as in the setup the
     # model constructor receives.
     perception = nn.Sequential(
         nn.Linear(inner_size, inner_size),
         nn.ReLU(),
     )
     model = IQNModel(
         dims=[1],
         num_actions=3,
         perception_net=perception,
         inner_size=inner_size,
         default_samples=default_samples,
     )

     phi, tau = model.get_phi(batch_size, num_samples)

     expected_phi_size = (batch_size, num_samples, inner_size)
     expected_tau_size = (batch_size, num_samples)
     self.assertEqual(expected_phi_size, phi.size())
     self.assertEqual(expected_tau_size, tau.size())
Пример #2
0
    def test_value(self):
        """Check what ``IQNModel.value`` returns for a dummy input.

        The state passed in (``None``) must come back unchanged, and the
        ``"q_value_distribution"`` entry of the returned value must have
        size ``(batch_size, num_actions, num_samples)``.
        """
        batch_size, num_actions = 5, 3
        state_dim = 10
        inner_size = 256
        default_samples, num_samples = 32, 8
        initial_state = None

        perception = nn.Sequential(
            nn.Linear(state_dim, inner_size),
            nn.ReLU(),
        )
        dummy_input = DummyInput(batch_size, state_dim)
        model = IQNModel(
            dims=[state_dim],
            num_actions=num_actions,
            perception_net=perception,
            inner_size=inner_size,
            default_samples=default_samples,
        )

        value, returned_state = model.value(
            dummy_input, initial_state, num_samples)

        self.assertEqual(initial_state, returned_state)
        expected_size = (batch_size, num_actions, num_samples)
        self.assertEqual(expected_size, value["q_value_distribution"].size())
 def initialize(self):
     """Build and return an ``IQN`` algorithm around a small ``IQNModel``.

     The model uses a Linear + ReLU perception net of width 256, a state
     shape of ``[1]`` and 3 actions.

     Returns:
         The constructed ``IQN`` instance.
     """
     hidden = 256
     perception = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU())
     model = IQNModel(
         dims=[1],
         num_actions=3,
         perception_net=perception,
         inner_size=hidden,
     )
     return IQN(
         model=model,
         exploration_end_steps=500000,
         update_ref_interval=100,
     )
def iqn(cnn, dims, num_actions, num_agents):
    """Assemble the settings dictionary for an IQN algorithm.

    Args:
        cnn: Network passed to ``IQNModel`` as ``perception_net``.
        dims: Passed to ``IQNModel`` as ``dims``.
        num_actions: Passed to ``IQNModel`` as ``num_actions``.
        num_agents: Divisor applied (integer division) to the exploration
            step budget and to the replay buffer capacity.

    Returns:
        A dict with the single key ``"RL"`` mapping to the algorithm and
        its sampling settings.
    """
    model = IQNModel(
        dims=dims,
        num_actions=num_actions,
        perception_net=cnn,
    )
    algorithm = IQN(
        model=model,
        gpu_id=0,
        exploration_end_steps=500000 // num_agents,
        update_ref_interval=100,
        grad_clip=5.0,
    )

    rl_settings = {
        "alg": algorithm,
        # sampling
        "agent_helper": ExpReplayHelper,
        "buffer_capacity": 200000 // num_agents,
        "num_experiences": 4,  # num per agent
        "num_seqs": 0,  # sample instances
        "sample_interval": 5,
    }
    return {"RL": rl_settings}