コード例 #1
0
ファイル: rl.py プロジェクト: stephenjfox/trax
def Quality(
    body=None,
    normalizer=None,
    batch_axes=None,
    mode='train',
    n_actions=2,
    head_init_range=None,
):
    """Builds a Q-value network: observation in, one value per action out.

    Layers, in order: the (optional) normalizer, the (optional) body, and a
    final ``tl.Dense(n_actions)`` head. The normalizer and body are each
    wrapped with ``_Batch(..., batch_axes)``. ``_Batch`` is defined elsewhere
    in the module; presumably it only adds batching when ``batch_axes`` is
    set (not visible here).

    Args:
      body: Optional callable ``body(mode=...)`` returning the trunk layer(s).
        When None, the trunk is an empty layer list.
      normalizer: Optional callable ``normalizer(mode=...)`` returning the
        input-normalization layer(s). When None, an empty layer list is used.
      batch_axes: Forwarded unchanged to ``_Batch``.
      mode: Passed to ``body`` and ``normalizer``.
      n_actions: Width of the output Dense layer.
      head_init_range: If not None, the output Dense layer gets
        ``tl.RandomUniformInitializer(lim=head_init_range)`` as its kernel
        initializer.

    Returns:
      A ``tl.Serial`` layer.
    """
    dense_kwargs = (
        {}
        if head_init_range is None
        else {'kernel_initializer':
                  tl.RandomUniformInitializer(lim=head_init_range)}
    )

    normalizer_stage = _Batch(
        [] if normalizer is None else normalizer(mode=mode), batch_axes)
    trunk_stage = _Batch(
        [] if body is None else body(mode=mode), batch_axes)
    q_head = tl.Dense(n_actions, **dense_kwargs)
    return tl.Serial(normalizer_stage, trunk_stage, q_head)
コード例 #2
0
ファイル: rl.py プロジェクト: weiddeng/trax
def PolicyAndValue(
    policy_distribution,
    body=None,
    policy_top=Policy,
    value_top=Value,
    normalizer=None,
    head_init_range=None,
    mode='train',
):
    """Attaches policy and value heads to a model body.

    The resulting model applies ``normalizer`` and then ``body`` to its input.
    It then feeds the result to two heads side by side through ``tl.Branch``:
    the policy head from ``policy_top`` and the value head from
    ``value_top``.

    Args:
      policy_distribution: Passed to ``policy_top`` as
        ``policy_distribution``.
      body: Optional callable ``body(mode=...)`` returning the shared trunk
        layer(s). When None, the trunk is an empty layer list.
      policy_top: Callable that builds the policy head.
      value_top: Callable that builds the value head.
      normalizer: Optional callable ``normalizer(mode=...)`` returning the
        input-normalization layer(s). When None, an empty layer list is used.
      head_init_range: If not None, forwarded to both ``policy_top`` and
        ``value_top`` as the ``head_init_range`` keyword. The ``Policy`` and
        ``Value`` builders in this module use it to build a
        ``tl.RandomUniformInitializer(lim=head_init_range)`` for their output
        Dense layer. When None, the keyword is not passed at all.
      mode: Passed to every sub-builder.

    Returns:
      A ``tl.Serial`` layer.
    """
    if normalizer is None:
        normalizer = lambda mode: []
    if body is None:
        body = lambda mode: []

    # The output layers live inside the heads, so the init range must reach
    # them. Previously a kernel initializer was built here but never used,
    # which meant head_init_range had no effect. The keyword is passed only
    # when set, so custom head builders that lack this parameter still work
    # with the default.
    head_kwargs = {}
    if head_init_range is not None:
        head_kwargs['head_init_range'] = head_init_range

    return tl.Serial(
        normalizer(mode=mode),
        body(mode=mode),
        tl.Branch(
            policy_top(
                policy_distribution=policy_distribution,
                mode=mode,
                **head_kwargs,
            ),
            value_top(mode=mode, **head_kwargs),
        ),
    )
コード例 #3
0
ファイル: rl.py プロジェクト: elliotthwang/trax
def Policy(
    policy_distribution,
    body=None,
    normalizer=None,
    head_init_range=None,
    batch_axes=None,
    mode='train',
):
    """Builds a policy network: normalizer, body, then a Dense policy head.

    The head's width is ``policy_distribution.n_inputs``. When ``batch_axes``
    is set, the normalizer and body stages are each wrapped in
    ``tl.BatchLeadingAxes(..., n_last_axes_to_keep=batch_axes)``. Otherwise
    they are used as-is.

    Args:
      policy_distribution: Object exposing ``n_inputs``, the output width of
        the head.
      body: Optional callable ``body(mode=...)`` returning the trunk layer(s).
        When None, the trunk is an empty layer list.
      normalizer: Optional callable ``normalizer(mode=...)`` returning the
        input-normalization layer(s). When None, an empty layer list is used.
      head_init_range: If not None, the head gets
        ``tl.RandomUniformInitializer(lim=head_init_range)`` as its kernel
        initializer.
      batch_axes: Passed as ``n_last_axes_to_keep`` to
        ``tl.BatchLeadingAxes``, or None to skip that wrapping.
      mode: Passed to ``body`` and ``normalizer``.

    Returns:
      A ``tl.Serial`` layer.
    """

    def wrap_batch(layer):
        # Only wrap when the caller asked for leading-axis batching.
        if batch_axes is None:
            return layer
        return tl.BatchLeadingAxes(layer, n_last_axes_to_keep=batch_axes)

    dense_kwargs = {}
    if head_init_range is not None:
        dense_kwargs['kernel_initializer'] = tl.RandomUniformInitializer(
            lim=head_init_range)

    normalizer_stage = wrap_batch(
        [] if normalizer is None else normalizer(mode=mode))
    trunk_stage = wrap_batch([] if body is None else body(mode=mode))
    policy_head = tl.Dense(policy_distribution.n_inputs, **dense_kwargs)
    return tl.Serial(normalizer_stage, trunk_stage, policy_head)
コード例 #4
0
def Policy(policy_distribution, body=None, head_init_range=None, mode='train'):
  """Builds a policy network: an optional body followed by a Dense head.

  Args:
    policy_distribution: Object exposing ``n_inputs``, the output width of
      the head.
    body: Optional callable ``body(mode=...)`` returning the trunk layer(s).
      When None, the trunk is an empty layer list.
    head_init_range: If not None, the head gets
      ``tl.RandomUniformInitializer(lim=head_init_range)`` as its kernel
      initializer.
    mode: Passed to ``body``.

  Returns:
    A ``tl.Serial`` layer.
  """
  dense_kwargs = (
      {}
      if head_init_range is None
      else {'kernel_initializer':
                tl.RandomUniformInitializer(lim=head_init_range)}
  )
  trunk = [] if body is None else body(mode=mode)
  policy_head = tl.Dense(policy_distribution.n_inputs, **dense_kwargs)
  return tl.Serial(trunk, policy_head)
コード例 #5
0
ファイル: rl.py プロジェクト: stephenjfox/trax
def Value(
    body=None,
    normalizer=None,
    inject_actions=False,
    inject_actions_n_layers=1,
    inject_actions_dim=64,
    batch_axes=None,
    mode='train',
    is_discrete=False,
    vocab_size=2,
    multiplicative_action_injection=False,
    head_init_range=None,
):
    """Attaches a scalar value head (``tl.Dense(1)``) to a model body.

    Stages, in order:
      1. ``normalizer`` and ``body``, each wrapped with
         ``_Batch(..., batch_axes)``. ``_Batch`` is defined elsewhere in the
         module and is not visible here.
      2. If ``inject_actions``: an action-injection stage that expects
         ``(body output, actions)`` on the stack. It encodes both (Dense for
         the body output; Embedding if ``is_discrete``, else Dense, for the
         actions) and combines them, either by addition or, if
         ``multiplicative_action_injection``, by ``x * tanh(a)`` followed by
         LayerNorm. It then runs an MLP with ``inject_actions_n_layers``
         layers of width ``inject_actions_dim``.
      3. ``tl.Dense(1)``. If ``head_init_range`` is not None, its kernel
         initializer is ``tl.RandomUniformInitializer(lim=head_init_range)``.

    Returns:
      A ``tl.Serial`` layer.
    """

    def action_stage(mode):
        # No injection requested: contribute an empty (no-op) layer list.
        if not inject_actions:
            return []

        if is_discrete:
            encode_actions = tl.Embedding(vocab_size, inject_actions_dim)
        else:
            encode_actions = tl.Dense(inject_actions_dim)
        encode_both = tl.Parallel(
            tl.Dense(inject_actions_dim),
            encode_actions,
        )

        if not multiplicative_action_injection:
            merge = tl.Add()
        else:
            merge = tl.Serial(
                tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a)),
                tl.LayerNorm(),  # compensate for reduced variance
            )

        mlp_widths = (inject_actions_dim,) * inject_actions_n_layers
        return tl.Serial(
            # Input: (body output, actions).
            encode_both,
            merge,
            models.MLP(
                layer_widths=mlp_widths,
                out_activation=True,
                flatten=False,
                mode=mode,
            ),
        )

    dense_kwargs = (
        {}
        if head_init_range is None
        else {'kernel_initializer':
                  tl.RandomUniformInitializer(lim=head_init_range)}
    )

    normalizer_stage = _Batch(
        [] if normalizer is None else normalizer(mode=mode), batch_axes)
    trunk_stage = _Batch(
        [] if body is None else body(mode=mode), batch_axes)
    injection_stage = action_stage(mode=mode)
    value_head = tl.Dense(1, **dense_kwargs)
    return tl.Serial(normalizer_stage, trunk_stage, injection_stage, value_head)
コード例 #6
0
 def test_random_uniform(self):
   """A default RandomUniformInitializer yields an array of the requested shape."""
   uniform_init = tl.RandomUniformInitializer()
   key = rng()
   sample = uniform_init(INPUT_SHAPE, key)
   self.assertEqual(sample.shape, INPUT_SHAPE)