def Quality(
    body=None,
    normalizer=None,
    batch_axes=None,
    mode='train',
    n_actions=2,
    head_init_range=None,
):
    """Builds a network mapping an observation to per-action values.

    Args:
      body: Callable accepting a `mode` keyword and returning the body
        layer(s). When None, a factory returning an empty list is used.
      normalizer: Callable accepting a `mode` keyword and returning the
        normalization layer(s). When None, a factory returning an empty list
        is used.
      batch_axes: Forwarded as the second argument of `_Batch` for both the
        normalizer and the body layers.
      mode: Forwarded to `normalizer` and `body`.
      n_actions: Width of the final `tl.Dense` head.
      head_init_range: When not None, the head's `kernel_initializer` is
        `tl.RandomUniformInitializer(lim=head_init_range)`.

    Returns:
      A `tl.Serial` of the wrapped normalizer, the wrapped body and the head.
    """

    def _no_layers(mode):
        return []

    make_body = _no_layers if body is None else body
    make_normalizer = _no_layers if normalizer is None else normalizer

    dense_kwargs = (
        {}
        if head_init_range is None
        else {
            'kernel_initializer': tl.RandomUniformInitializer(
                lim=head_init_range
            )
        }
    )

    normalizer_part = _Batch(make_normalizer(mode=mode), batch_axes)
    body_part = _Batch(make_body(mode=mode), batch_axes)
    head = tl.Dense(n_actions, **dense_kwargs)
    return tl.Serial(normalizer_part, body_part, head)
def PolicyAndValue(
    policy_distribution,
    body=None,
    policy_top=Policy,
    value_top=Value,
    normalizer=None,
    head_init_range=None,
    mode='train',
):
    """Attaches policy and value heads to a model body.

    Args:
      policy_distribution: Passed to `policy_top` as `policy_distribution`.
      body: Callable accepting a `mode` keyword and returning the shared body
        layer(s). When None, a factory returning an empty list is used.
      policy_top: Callable building the policy head; called with
        `policy_distribution` and `mode` keywords (plus `head_init_range`
        when that is set).
      value_top: Callable building the value head; called with a `mode`
        keyword (plus `head_init_range` when that is set).
      normalizer: Callable accepting a `mode` keyword and returning the
        normalization layer(s). When None, a factory returning an empty list
        is used.
      head_init_range: When not None, forwarded to both `policy_top` and
        `value_top` as `head_init_range`. Previously this argument was
        accepted but never reached either head.
      mode: Forwarded to `normalizer`, `body`, `policy_top` and `value_top`.

    Returns:
      A `tl.Serial` of the normalizer, the body and a `tl.Branch` holding the
      policy and value heads.
    """
    if normalizer is None:
        normalizer = lambda mode: []
    if body is None:
        body = lambda mode: []

    # Only forward `head_init_range` when the caller actually set it, so that
    # custom tops which do not accept that keyword keep working by default.
    top_kwargs = {'mode': mode}
    if head_init_range is not None:
        top_kwargs['head_init_range'] = head_init_range

    return tl.Serial(
        normalizer(mode=mode),
        body(mode=mode),
        tl.Branch(
            policy_top(policy_distribution=policy_distribution, **top_kwargs),
            value_top(**top_kwargs),
        ),
    )
def Policy(
    policy_distribution,
    body=None,
    normalizer=None,
    head_init_range=None,
    batch_axes=None,
    mode='train',
):
    """Builds a policy network: normalizer, body, then a dense policy head.

    Args:
      policy_distribution: Object whose `n_inputs` attribute sets the width
        of the final `tl.Dense` head.
      body: Callable accepting a `mode` keyword and returning the body
        layer(s). When None, a factory returning an empty list is used.
      normalizer: Callable accepting a `mode` keyword and returning the
        normalization layer(s). When None, a factory returning an empty list
        is used.
      head_init_range: When not None, the head's `kernel_initializer` is
        `tl.RandomUniformInitializer(lim=head_init_range)`.
      batch_axes: When not None, the normalizer and body layers are each
        wrapped in `tl.BatchLeadingAxes(..., n_last_axes_to_keep=batch_axes)`;
        when None they are used unwrapped.
      mode: Forwarded to `normalizer` and `body`.

    Returns:
      A `tl.Serial` of the normalizer, the body and the policy head.
    """

    def _no_layers(mode):
        return []

    def _maybe_batch(layers):
        # Leave the layers untouched unless a batch-axes setting was given.
        if batch_axes is None:
            return layers
        return tl.BatchLeadingAxes(layers, n_last_axes_to_keep=batch_axes)

    make_body = _no_layers if body is None else body
    make_normalizer = _no_layers if normalizer is None else normalizer

    dense_kwargs = {}
    if head_init_range is not None:
        dense_kwargs['kernel_initializer'] = tl.RandomUniformInitializer(
            lim=head_init_range
        )

    normalizer_part = _maybe_batch(make_normalizer(mode=mode))
    body_part = _maybe_batch(make_body(mode=mode))
    head = tl.Dense(policy_distribution.n_inputs, **dense_kwargs)
    return tl.Serial(normalizer_part, body_part, head)
def Policy(policy_distribution, body=None, head_init_range=None, mode='train'):
    """Builds a policy network: a body followed by a dense policy head.

    Args:
      policy_distribution: Object whose `n_inputs` attribute sets the width
        of the final `tl.Dense` head.
      body: Callable accepting a `mode` keyword and returning the body
        layer(s). When None, a factory returning an empty list is used.
      head_init_range: When not None, the head's `kernel_initializer` is
        `tl.RandomUniformInitializer(lim=head_init_range)`.
      mode: Forwarded to `body`.

    Returns:
      A `tl.Serial` of the body and the policy head.
    """

    def _no_layers(mode):
        return []

    make_body = _no_layers if body is None else body

    dense_kwargs = (
        {}
        if head_init_range is None
        else {
            'kernel_initializer': tl.RandomUniformInitializer(
                lim=head_init_range
            )
        }
    )

    body_part = make_body(mode=mode)
    head = tl.Dense(policy_distribution.n_inputs, **dense_kwargs)
    return tl.Serial(body_part, head)
def Value(
    body=None,
    normalizer=None,
    inject_actions=False,
    inject_actions_n_layers=1,
    inject_actions_dim=64,
    batch_axes=None,
    mode='train',
    is_discrete=False,
    vocab_size=2,
    multiplicative_action_injection=False,
    head_init_range=None,
):
    """Builds a value network: normalizer, body, optional action injection, head.

    Args:
      body: Callable accepting a `mode` keyword and returning the body
        layer(s). When None, a factory returning an empty list is used.
      normalizer: Callable accepting a `mode` keyword and returning the
        normalization layer(s). When None, a factory returning an empty list
        is used.
      inject_actions: When truthy, an action-injection stack is placed
        between the body and the head; otherwise nothing is inserted.
      inject_actions_n_layers: Number of layers in the MLP that follows the
        injection.
      inject_actions_dim: Width of both encoders and of every MLP layer.
      batch_axes: Forwarded as the second argument of `_Batch` for both the
        normalizer and the body layers.
      mode: Forwarded to `normalizer`, `body` and the injection MLP.
      is_discrete: When truthy, actions are encoded with
        `tl.Embedding(vocab_size, inject_actions_dim)`; otherwise with
        `tl.Dense(inject_actions_dim)`.
      vocab_size: Vocabulary size of the action embedding.
      multiplicative_action_injection: When truthy, the encoded inputs are
        combined as `x * tanh(a)` followed by `tl.LayerNorm`; otherwise they
        are combined with `tl.Add`.
      head_init_range: When not None, the head's `kernel_initializer` is
        `tl.RandomUniformInitializer(lim=head_init_range)`.

    Returns:
      A `tl.Serial` ending in a `tl.Dense(1)` head.
    """

    def _no_layers(mode):
        return []

    make_body = _no_layers if body is None else body
    make_normalizer = _no_layers if normalizer is None else normalizer

    def _build_action_injection(mode):
        # Nothing to insert when action injection is disabled.
        if not inject_actions:
            return []

        action_encoder = (
            tl.Embedding(vocab_size, inject_actions_dim)
            if is_discrete
            else tl.Dense(inject_actions_dim)
        )
        encoders = tl.Parallel(
            tl.Dense(inject_actions_dim),
            action_encoder,
        )

        if multiplicative_action_injection:
            combiner = tl.Serial(
                tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a)),
                tl.LayerNorm(),  # compensate for reduced variance
            )
        else:
            combiner = tl.Add()

        mlp_widths = (inject_actions_dim,) * inject_actions_n_layers
        return tl.Serial(
            # Input: (body output, actions).
            encoders,
            combiner,
            models.MLP(
                layer_widths=mlp_widths,
                out_activation=True,
                flatten=False,
                mode=mode,
            ),
        )

    dense_kwargs = {}
    if head_init_range is not None:
        dense_kwargs['kernel_initializer'] = tl.RandomUniformInitializer(
            lim=head_init_range
        )

    normalizer_part = _Batch(make_normalizer(mode=mode), batch_axes)
    body_part = _Batch(make_body(mode=mode), batch_axes)
    injection_part = _build_action_injection(mode=mode)
    head = tl.Dense(1, **dense_kwargs)
    return tl.Serial(normalizer_part, body_part, injection_part, head)
def test_random_uniform(self):
    """Checks that the default RandomUniformInitializer output has INPUT_SHAPE."""
    initializer = tl.RandomUniformInitializer()
    produced = initializer(INPUT_SHAPE, rng())
    self.assertEqual(produced.shape, INPUT_SHAPE)