Example 1
def __call__(self, x):
  return hk.Sequential([hk.Flatten(), hk.Linear(3), jax.nn.relu])(x)
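A minimal usage sketch for this pattern, assuming the __call__ above lives in an hk.Module subclass (the class name FlattenLinearRelu below is hypothetical):

import haiku as hk
import jax
import jax.numpy as jnp

class FlattenLinearRelu(hk.Module):
  # Hypothetical module wrapping the __call__ shown above.
  def __call__(self, x):
    return hk.Sequential([hk.Flatten(), hk.Linear(3), jax.nn.relu])(x)

def forward(x):
  return FlattenLinearRelu()(x)

forward = hk.without_apply_rng(hk.transform(forward))
x = jnp.ones([8, 28, 28, 1])                      # dummy image batch
params = forward.init(jax.random.PRNGKey(42), x)
out = forward.apply(params, x)                    # shape (8, 3)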
Example 2
def network(inputs: jnp.ndarray) -> jnp.ndarray:
    flat_inputs = hk.Flatten()(inputs)
    mlp = hk.nets.MLP([64, 64, action_spec.num_values])
    action_values = mlp(flat_inputs)
    return action_values
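To run this Q-network it has to be transformed first; a sketch, assuming action_spec.num_values gives the number of discrete actions (as in dm_env specs) and using an illustrative observation shape:

import haiku as hk
import jax
import jax.numpy as jnp

net = hk.without_apply_rng(hk.transform(network))
dummy_obs = jnp.zeros([1, 84, 84, 4])            # illustrative observation batch
params = net.init(jax.random.PRNGKey(0), dummy_obs)
q_values = net.apply(params, dummy_obs)          # shape (1, action_spec.num_values)
greedy_action = jnp.argmax(q_values, axis=-1)    # highest-valued action per observation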
Example 3
def _flatten(x, num_batch_dims: int):
    return hk.Flatten(preserve_dims=num_batch_dims)(x)
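preserve_dims=n keeps the first n axes and flattens everything after them into a single axis; a quick shape check (shapes chosen for illustration):

import haiku as hk
import jax
import jax.numpy as jnp

f = hk.without_apply_rng(hk.transform(lambda x: _flatten(x, num_batch_dims=2)))
x = jnp.ones([4, 5, 6, 7])
params = f.init(jax.random.PRNGKey(0), x)  # Flatten has no parameters
y = f.apply(params, x)                     # shape (4, 5, 42): the 6*7 tail is flattened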
Example 4
def q(obs):
    network = hk.Sequential(
        [hk.Flatten(),
         nets.MLP([FLAGS.hidden_units, num_actions])])
    return network(obs)
Example 5
def network(x):
  model = hk.Sequential([
      hk.Flatten(),
      hk.nets.MLP([50, 50, spec.actions.num_values])
  ])
  return model(x)
Example 6
def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Simple Q-network with randomized prior function."""
    net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    x = hk.Flatten()(inputs)
    return net(x) + prior_scale * lax.stop_gradient(prior_net(x))
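The lax.stop_gradient is what makes this a randomized prior function: prior_net receives zero gradient and therefore keeps its random initialization, while net trains normally. A sketch of verifying that, assuming hidden_sizes, action_spec and prior_scale are in scope:

import haiku as hk
import jax
import jax.numpy as jnp

fwd = hk.without_apply_rng(hk.transform(network))
obs = jnp.ones([2, 8])                           # illustrative observation batch
params = fwd.init(jax.random.PRNGKey(0), obs)
loss = lambda p: jnp.sum(fwd.apply(p, obs) ** 2)
grads = jax.grad(loss)(params)                   # entries for the prior MLP are all zeros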
Example 7
def forward(x):
  network = hk.Sequential([hk.Flatten(preserve_dims=1), hk.Linear(3)])
  return network(x)
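Note that preserve_dims=1 is hk.Flatten's default, so this is equivalent to plain hk.Flatten(): the leading batch axis is kept and all remaining axes are merged into one.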
Example 8
def network(inputs: jnp.ndarray, state: hk.LSTMState):
    return hk.DeepRNN([hk.Flatten(),
                       hk.LSTM(output_size)])(inputs, state)
Example 9
def initial_state(batch_size: int):
    network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
    return network.initial_state(batch_size)
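This pairs with the previous example: initial_state builds the zero LSTM state and network runs one step on it. A wiring sketch, assuming output_size is in scope and using an illustrative input shape:

import haiku as hk
import jax
import jax.numpy as jnp

net = hk.without_apply_rng(hk.transform(network))
init_fn = hk.without_apply_rng(hk.transform(initial_state))

state = init_fn.apply(init_fn.init(None, 1), 1)  # all-zeros state; no parameters involved
x = jnp.ones([1, 4, 4])                          # one timestep of input
params = net.init(jax.random.PRNGKey(0), x, state)
out, next_state = net.apply(params, x, state)    # out has shape (1, output_size)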
Example 10
                     shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(name="nets.MLP",
                     create=lambda: hk.nets.MLP([3, 4, 5]),
                     shape=(BATCH_SIZE, 3)),
)

# Modules that require input to have a batch dimension.
BATCH_MODULES = (
    ModuleDescriptor(name="BatchNorm",
                     create=lambda: Training(hk.BatchNorm(True, True, 0.9)),
                     shape=(BATCH_SIZE, 2, 2, 3)),
    ModuleDescriptor(name="Bias",
                     create=lambda: hk.Bias(),
                     shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(name="Flatten",
                     create=lambda: hk.Flatten(),
                     shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(name="InstanceNorm",
                     create=lambda: hk.InstanceNorm(True, True),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="LayerNorm",
                     create=lambda: hk.LayerNorm(1, True, True),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="SpectralNorm",
                     create=lambda: hk.SpectralNorm(),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="nets.ResNet",
                     create=lambda: Training(hk.nets.ResNet(
                         (3, 4, 6, 3), 1000)),
                     shape=(BATCH_SIZE, 3, 3, 2)),
    # pylint: disable=g-long-lambda
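For reference, ModuleDescriptor and Training come from Haiku's integration-test harness rather than its public API; a rough sketch of what such helpers might look like (assumed shapes, not the exact upstream definitions):

from typing import Callable, NamedTuple, Sequence

class ModuleDescriptor(NamedTuple):
  name: str                       # label used in test output
  create: Callable[[], Callable]  # factory, called inside an hk.transform
  shape: Sequence[int]            # input shape used to exercise the module

class Training:
  """Assumed wrapper: calls the module with is_training=True."""

  def __init__(self, module):
    self._module = module

  def __call__(self, x):
    return self._module(x, is_training=True)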
Example 11
def policy(inputs: jnp.ndarray):
    return hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
        lambda x: jnp.argmax(x, axis=-1),
    ])(inputs)
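The trailing lambda x: jnp.argmax(x, axis=-1) makes this a greedy policy: it returns integer action indices instead of logits, which also means the composite is not differentiable end to end.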
Example 12
def func(S, A, is_training):
    logits = hk.Sequential((hk.Flatten(), hk.Linear(20, w_init=jnp.zeros)))
    X = jax.vmap(jnp.kron)(S, A)  # S and A are one-hot encoded
    return {'logits': logits(X)}
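Because S and A are one-hot, jnp.kron of the two is itself a one-hot vector indexing the (state, action) pair, so the linear layer is effectively tabular, with one weight per pair; w_init=jnp.zeros starts every logit at zero. A small check of the encoding (sizes chosen for illustration):

import jax
import jax.numpy as jnp

S = jnp.eye(3)[jnp.array([0, 2])]  # batch of one-hot states,  shape (2, 3)
A = jnp.eye(2)[jnp.array([1, 0])]  # batch of one-hot actions, shape (2, 2)
X = jax.vmap(jnp.kron)(S, A)       # shape (2, 6); row i is one-hot at s_i*2 + a_i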
Example 13
def mlp_model(inputs: jnp.ndarray) -> jnp.ndarray:
  flattened = hk.Flatten()(inputs)
  return hk.nets.MLP(layer_dims)(flattened)
Example 14
def forward_pass(batch):
  network = hk.Sequential([
      hk.Flatten(),
      hk.Linear(num_classes),
  ])
  return network(batch['x'])
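A training-loss sketch around forward_pass, assuming batches are dicts with 'x' inputs and integer 'y' labels, that num_classes is in scope, and using optax for the cross-entropy:

import haiku as hk
import jax
import jax.numpy as jnp
import optax

net = hk.without_apply_rng(hk.transform(forward_pass))

def loss_fn(params, batch):
  logits = net.apply(params, batch)
  labels = jax.nn.one_hot(batch['y'], num_classes)
  return jnp.mean(optax.softmax_cross_entropy(logits, labels))

grad_fn = jax.grad(loss_fn)  # plug into an optax optimizer in the training loop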
Example 15
def func(S, A, is_training):
    value = hk.Sequential(
        (hk.Flatten(), hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    X = jax.vmap(jnp.kron)(S, A)  # S and A are one-hot encoded
    return value(X)
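Compared with Example 12, this is the value-function counterpart: hk.Linear(1) yields shape (batch, 1) and the trailing jnp.ravel squeezes it to a flat (batch,) vector of state-action values, again initialized to zero by w_init=jnp.zeros.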
Example 16
def network(x: jnp.ndarray) -> jnp.ndarray:
    mlp = hk.Sequential(
        [hk.Flatten(),
         hk.nets.MLP([64, 64, action_spec.num_values])])
    return mlp(x)