コード例 #1
0
 def __call__(self, x):
   return hk.Sequential([hk.Flatten(), hk.Linear(3), jax.nn.relu])(x)
コード例 #2
0
 def network(inputs: jnp.ndarray) -> jnp.ndarray:
     flat_inputs = hk.Flatten()(inputs)
     mlp = hk.nets.MLP([64, 64, action_spec.num_values])
     action_values = mlp(flat_inputs)
     return action_values
コード例 #3
0
def _flatten(x, num_batch_dims: int):
    return hk.Flatten(preserve_dims=num_batch_dims)(x)
コード例 #4
0
ファイル: simple_dqn.py プロジェクト: deepmind/rlax
 def q(obs):
     network = hk.Sequential(
         [hk.Flatten(),
          nets.MLP([FLAGS.hidden_units, num_actions])])
     return network(obs)
コード例 #5
0
ファイル: agent_test.py プロジェクト: whl19910402/acme
 def network(x):
   model = hk.Sequential([
       hk.Flatten(),
       hk.nets.MLP([50, 50, spec.actions.num_values])
   ])
   return model(x)
コード例 #6
0
 def network(inputs: jnp.ndarray) -> jnp.ndarray:
     """Simple Q-network with randomized prior function."""
     net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
     prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
     x = hk.Flatten()(inputs)
     return net(x) + prior_scale * lax.stop_gradient(prior_net(x))
コード例 #7
0
ファイル: transformed_test.py プロジェクト: deepmind/distrax
 def forward(x):
   network = hk.Sequential([hk.Flatten(preserve_dims=1), hk.Linear(3)])
   return network(x)
コード例 #8
0
ファイル: actors_test.py プロジェクト: shadowkun/acme
 def network(inputs: jnp.ndarray, state: hk.LSTMState):
     return hk.DeepRNN([hk.Flatten(),
                        hk.LSTM(output_size)])(inputs, state)
コード例 #9
0
ファイル: actors_test.py プロジェクト: shadowkun/acme
 def initial_state(batch_size: int):
     network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
     return network.initial_state(batch_size)
コード例 #10
0
                     shape=(BATCH_SIZE, 2, 2)),
    ModuleDescriptor(name="nets.MLP",
                     create=lambda: hk.nets.MLP([3, 4, 5]),
                     shape=(BATCH_SIZE, 3)),
)

# Modules that require input to have a batch dimension.
BATCH_MODULES = (
    ModuleDescriptor(name="BatchNorm",
                     create=lambda: Training(hk.BatchNorm(True, True, 0.9)),
                     shape=(BATCH_SIZE, 2, 2, 3)),
    ModuleDescriptor(name="Bias",
                     create=lambda: hk.Bias(),
                     shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(name="Flatten",
                     create=lambda: hk.Flatten(),
                     shape=(BATCH_SIZE, 3, 3, 3)),
    ModuleDescriptor(name="InstanceNorm",
                     create=lambda: hk.InstanceNorm(True, True),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="LayerNorm",
                     create=lambda: hk.LayerNorm(1, True, True),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="SpectralNorm",
                     create=lambda: hk.SpectralNorm(),
                     shape=(BATCH_SIZE, 3, 2)),
    ModuleDescriptor(name="nets.ResNet",
                     create=lambda: Training(hk.nets.ResNet(
                         (3, 4, 6, 3), 1000)),
                     shape=(BATCH_SIZE, 3, 3, 2)),
    # pylint: disable=g-long-lambda
コード例 #11
0
ファイル: actors_test.py プロジェクト: shadowkun/acme
 def policy(inputs: jnp.ndarray):
     return hk.Sequential([
         hk.Flatten(),
         hk.Linear(env_spec.actions.num_values),
         lambda x: jnp.argmax(x, axis=-1),
     ])(inputs)
コード例 #12
0
def func(S, A, is_training):
    logits = hk.Sequential((hk.Flatten(), hk.Linear(20, w_init=jnp.zeros)))
    X = jax.vmap(jnp.kron)(S, A)  # S and A are one-hot encoded
    return {'logits': logits(X)}
コード例 #13
0
ファイル: lookahead_mnist.py プロジェクト: n2cholas/optax
 def mlp_model(inputs: jnp.ndarray) -> jnp.ndarray:
   flattened = hk.Flatten()(inputs)
   return hk.nets.MLP(layer_dims)(flattened)
コード例 #14
0
 def forward_pass(batch):
   network = hk.Sequential([
       hk.Flatten(),
       hk.Linear(num_classes),
   ])
   return network(batch['x'])
コード例 #15
0
def func(S, A, is_training):
    value = hk.Sequential(
        (hk.Flatten(), hk.Linear(1, w_init=jnp.zeros), jnp.ravel))
    X = jax.vmap(jnp.kron)(S, A)  # S and A are one-hot encoded
    return value(X)
コード例 #16
0
ファイル: run_bc_jax.py プロジェクト: pchtsp/acme
 def network(x: jnp.DeviceArray) -> jnp.DeviceArray:
     mlp = hk.Sequential(
         [hk.Flatten(),
          hk.nets.MLP([64, 64, action_spec.num_values])])
     return mlp(x)