예제 #1
0
파일: rl.py 프로젝트: stephenjfox/trax
 def ActionInjector(mode):
     """Assemble the layers that fold actions into the body output.

     Several names used here are not defined in this function and must be
     supplied by the surrounding scope: ``inject_actions``, ``is_discrete``,
     ``vocab_size``, ``inject_actions_dim``, ``inject_actions_n_layers`` and
     ``multiplicative_action_injection``.

     Args:
       mode: Forwarded unchanged to ``models.MLP``.

     Returns:
       A ``tl.Serial`` combinator whose input is (body output, actions), or
       an empty list when ``inject_actions`` is falsy.
     """
     if not inject_actions:
         return []

     # Discrete actions go through an embedding; otherwise a dense projection.
     encode_action = (
         tl.Embedding(vocab_size, inject_actions_dim)
         if is_discrete
         else tl.Dense(inject_actions_dim)
     )
     # First branch projects the body output, second branch encodes actions.
     both_encoders = tl.Parallel(tl.Dense(inject_actions_dim), encode_action)

     if multiplicative_action_injection:
         tanh_gate = tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a))
         # LayerNorm compensates for the reduced variance after gating.
         combine = tl.Serial(tanh_gate, tl.LayerNorm())
     else:
         combine = tl.Add()

     widths = (inject_actions_dim,) * inject_actions_n_layers
     head = models.MLP(
         layer_widths=widths,
         out_activation=True,
         flatten=False,
         mode=mode,
     )
     return tl.Serial(both_encoders, combine, head)
예제 #2
0
 def model_fn(mode='train'):
     """Return a Dropout -> BatchNorm -> MLP stack built for ``mode``.

     ``n_classes`` is not defined here; it must come from the surrounding
     scope.
     """
     dropout = tl.Dropout(mode=mode, rate=0.1)
     batch_norm = tl.BatchNorm(mode=mode)
     classifier = models.MLP(
         d_hidden=16,
         n_output_classes=n_classes,
         mode=mode,
     )
     return tl.Serial(dropout, batch_norm, classifier)
예제 #3
0
 def model_fn(mode='train'):
     """Return a Dropout -> BatchNorm -> MLP stack built for ``mode``.

     The MLP is given layer widths ``(16, 16, n_classes)``; ``n_classes`` is
     not defined here and must come from the surrounding scope.
     """
     dropout = tl.Dropout(mode=mode, rate=0.1)
     batch_norm = tl.BatchNorm(mode=mode)
     widths = (16, 16, n_classes)
     classifier = models.MLP(layer_widths=widths, mode=mode)
     return tl.Serial(dropout, batch_norm, classifier)
예제 #4
0
 def model_fn(mode="train"):
   """Return a Dropout -> BatchNorm -> MLP model built for ``mode``.

   ``n_classes`` is not defined here; it must come from the surrounding
   scope.
   """
   dropout = layers.Dropout(mode=mode, rate=0.1)
   batch_norm = layers.BatchNorm(mode=mode)
   classifier = models.MLP(
       d_hidden=16,
       n_output_classes=n_classes,
       mode=mode,
   )
   return layers.Model(dropout, batch_norm, classifier)