def _make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore:
  """Builds a small recurrent policy/value core for `action_spec`.

  The core is a `snt.DeepRNN` over, in order: a flatten layer, a 20-unit
  LSTM, an MLP with two 50-unit layers, and a `networks.PolicyValueHead`
  sized by `action_spec.num_values`.
  """
  core_layers = [
      snt.Flatten(),
      snt.LSTM(20),
      snt.nets.MLP([50, 50]),
      networks.PolicyValueHead(action_spec.num_values),
  ]
  return snt.DeepRNN(core_layers)
def test_mcts(self):
  # Small fake discrete environment for the agent to interact with.
  n_actions = 5
  env = fakes.DiscreteEnvironment(
      num_actions=n_actions,
      num_observations=10,
      obs_dtype=np.float32,
      episode_length=10,
  )
  env_spec = specs.make_environment_spec(env)

  # Feed-forward policy/value network sized to the action spec.
  network = snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([50, 50]),
      networks.PolicyValueHead(env_spec.actions.num_values),
  ])
  model = simulator.Simulator(env)
  optimizer = snt.optimizers.Adam(1e-3)

  # MCTS agent with deliberately small settings so the test stays quick.
  agent = mcts.MCTS(
      environment_spec=env_spec,
      network=network,
      model=model,
      optimizer=optimizer,
      n_step=1,
      discount=1.0,
      replay_capacity=100,
      num_simulations=10,
      batch_size=10,
  )

  # Smoke test: there are no assertions, the test only checks that running a
  # couple of episodes through the environment loop raises no errors.
  acme.EnvironmentLoop(env, agent).run(num_episodes=2)
def __init__(self,
             action_spec: specs.DiscreteArray,
             name: Optional[Text] = None,
             *,
             flat_hidden_sizes=(64, 64),
             core_hidden_sizes=(50, 50),
             gru_hidden_size: int = 512):
  """Creates the flat MLP and the recurrent policy/value core.

  Args:
    action_spec: discrete action spec; its `num_values` sizes the
      `networks.PolicyValueHead` at the end of `self.rnn`.
    name: optional module name, forwarded to the parent constructor.
    flat_hidden_sizes: iterable of layer output sizes for the `mlp_1` MLP
      stored in `self.flat`. Defaults to (64, 64).
    core_hidden_sizes: iterable of layer output sizes for the `mlp_2` MLP
      at the front of `self.rnn`. Defaults to (50, 50).
    gru_hidden_size: hidden size of the `gru` cell in `self.rnn`.
      Defaults to 512.
  """
  super().__init__(name=name)
  # Hidden-layer sizes used to be hard-coded (see the former TODO). They are
  # now keyword-only arguments, so existing positional callers are unaffected
  # and callers can wire them to flags if desired. The defaults reproduce the
  # previous architecture exactly.
  self.flat = snt.nets.MLP(list(flat_hidden_sizes), name="mlp_1")
  self.rnn = snt.DeepRNN([
      snt.nets.MLP(
          list(core_hidden_sizes), activate_final=True, name="mlp_2"),
      snt.GRU(gru_hidden_size, name="gru"),
      networks.PolicyValueHead(action_spec.num_values),
  ])
def __init__(self,
             action_spec: specs.DiscreteArray,
             name: Optional[Text] = None):
  """Creates the spatial, flat and recurrent sub-modules.

  Args:
    action_spec: discrete action spec; its `num_values` sizes the
      `networks.PolicyValueHead` at the end of `self.rnn`.
    name: optional module name, forwarded to the parent constructor.
  """
  super().__init__(name=name)

  def _conv(channels, kernel, layer_name):
    # Every convolution here uses stride 1 and NHWC layout.
    return snt.Conv2D(channels, kernel, 1, data_format="NHWC", name=layer_name)

  # Spatial sub-modules.
  self.conv1 = _conv(16, 1, "conv_1")
  self.conv2 = _conv(32, 3, "conv_2")
  self.conv3 = _conv(64, 3, "conv_3")
  self.conv4 = _conv(32, 3, "conv_4")
  self.flatten = snt.Flatten()
  self.fc1 = snt.Linear(256, name="fc_1")

  # Flat sub-module.
  self.flat = snt.nets.MLP([64, 64], name="mlp_1")

  # Recurrent core: MLP (final activation on), GRU, then policy/value head.
  core_layers = [
      snt.nets.MLP([50, 50], activate_final=True, name="mlp_2"),
      snt.GRU(512, name="gru"),
      networks.PolicyValueHead(action_spec.num_values),
  ]
  self.rnn = snt.DeepRNN(core_layers)
def make_network(action_spec: specs.DiscreteArray) -> snt.Module:
  """Builds a feed-forward policy/value network for `action_spec`.

  The network is a `snt.Sequential` over, in order: a flatten layer, an MLP
  with two 50-unit layers, and a `networks.PolicyValueHead` sized by
  `action_spec.num_values`.
  """
  flatten = snt.Flatten()
  torso = snt.nets.MLP([50, 50])
  head = networks.PolicyValueHead(action_spec.num_values)
  return snt.Sequential([flatten, torso, head])