Exemple #1
0
def _make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore:
    """Assemble the recurrent policy/value network.

    Args:
        action_spec: Discrete action spec; its `num_values` is passed to the
            `PolicyValueHead`.

    Returns:
        A `snt.DeepRNN` chaining `snt.Flatten()`, `snt.LSTM(20)`, a
        `[50, 50]` MLP and a `PolicyValueHead`, in that order.
    """
    layers = [
        snt.Flatten(),
        snt.LSTM(20),
        snt.nets.MLP([50, 50]),
        networks.PolicyValueHead(action_spec.num_values),
    ]
    return snt.DeepRNN(layers)
Exemple #2
0
    def test_mcts(self):
        """Smoke-test the MCTS agent by running it for two episodes."""
        # A fake discrete environment stands in for a real task.
        action_count = 5
        env = fakes.DiscreteEnvironment(
            num_actions=action_count,
            num_observations=10,
            obs_dtype=np.float32,
            episode_length=10,
        )
        env_spec = specs.make_environment_spec(env)

        # Feed-forward policy/value network sized from the action spec.
        layers = [
            snt.Flatten(),
            snt.nets.MLP([50, 50]),
            networks.PolicyValueHead(env_spec.actions.num_values),
        ]
        policy_value_net = snt.Sequential(layers)
        env_model = simulator.Simulator(env)
        adam = snt.optimizers.Adam(1e-3)

        # Build the agent under test.
        agent = mcts.MCTS(
            environment_spec=env_spec,
            network=policy_value_net,
            model=env_model,
            optimizer=adam,
            n_step=1,
            discount=1.,
            replay_capacity=100,
            num_simulations=10,
            batch_size=10,
        )

        # Deliberately no assertions: the test passes as long as the agent
        # can be driven through the environment loop without raising.
        acme.EnvironmentLoop(env, agent).run(num_episodes=2)
Exemple #3
0
 def __init__(self,
              action_spec: specs.DiscreteArray,
              name: Optional[Text] = None):
     """Create the flat MLP and the recurrent policy/value stack.

     Args:
         action_spec: Discrete action spec; its `num_values` is handed to
             the `PolicyValueHead`.
         name: Optional module name forwarded to the parent constructor.
     """
     super().__init__(name=name)

     # TODO: expose the hidden layer sizes through flags.
     self.flat = snt.nets.MLP([64, 64], name="mlp_1")

     # Recurrent core: MLP -> GRU -> policy/value head, applied in sequence.
     recurrent_layers = [
         snt.nets.MLP([50, 50], activate_final=True, name="mlp_2"),
         snt.GRU(512, name="gru"),
         networks.PolicyValueHead(action_spec.num_values),
     ]
     self.rnn = snt.DeepRNN(recurrent_layers)
Exemple #4
0
    def __init__(self,
                 action_spec: specs.DiscreteArray,
                 name: Optional[Text] = None):
        """Create the convolutional, dense and recurrent sub-modules.

        Only layer construction happens here; how the layers are wired
        together is defined outside this constructor.

        Args:
            action_spec: Discrete action spec; its `num_values` is handed to
                the `PolicyValueHead`.
            name: Optional module name forwarded to the parent constructor.
        """
        super().__init__(name=name)

        # Spatial layers: four stride-1 NHWC convolutions, a flatten and a
        # dense layer.
        self.conv1 = snt.Conv2D(
            output_channels=16, kernel_shape=1, stride=1,
            data_format="NHWC", name="conv_1")
        self.conv2 = snt.Conv2D(
            output_channels=32, kernel_shape=3, stride=1,
            data_format="NHWC", name="conv_2")
        self.conv3 = snt.Conv2D(
            output_channels=64, kernel_shape=3, stride=1,
            data_format="NHWC", name="conv_3")
        self.conv4 = snt.Conv2D(
            output_channels=32, kernel_shape=3, stride=1,
            data_format="NHWC", name="conv_4")
        self.flatten = snt.Flatten()
        self.fc1 = snt.Linear(256, name="fc_1")

        # Flat layers: an MLP plus the recurrent policy/value core.
        self.flat = snt.nets.MLP([64, 64], name="mlp_1")
        recurrent_layers = [
            snt.nets.MLP([50, 50], activate_final=True, name="mlp_2"),
            snt.GRU(512, name="gru"),
            networks.PolicyValueHead(action_spec.num_values),
        ]
        self.rnn = snt.DeepRNN(recurrent_layers)
Exemple #5
0
def make_network(action_spec: specs.DiscreteArray) -> snt.Module:
    """Assemble the feed-forward policy/value network.

    Args:
        action_spec: Discrete action spec; its `num_values` is passed to the
            `PolicyValueHead`.

    Returns:
        A `snt.Sequential` chaining `snt.Flatten()`, a `[50, 50]` MLP and a
        `PolicyValueHead`, in that order.
    """
    layers = [
        snt.Flatten(),
        snt.nets.MLP([50, 50]),
        networks.PolicyValueHead(action_spec.num_values),
    ]
    return snt.Sequential(layers)