Example #1
    def test_extract_inner_model(self):
        vocab_size = 3

        # A tiny inner model; n_layers=0 means no decoder blocks.
        inner_model = transformer.TransformerLM(vocab_size=vocab_size,
                                                d_model=2,
                                                d_ff=2,
                                                n_layers=0)
        obs_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=vocab_size)
        serialized_model = serialization_utils.SerializedModel(
            inner_model,
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )

        obs_sig = shapes.ShapeDtype((1, 2))
        act_sig = shapes.ShapeDtype((1, 1))
        (weights,
         state) = serialized_model.init(input_signature=(obs_sig, act_sig,
                                                         obs_sig, obs_sig))
        # Unwrap the serialization wrappers to recover the inner model's
        # weights and state, then check that the bare TransformerLM runs.
        (inner_weights,
         inner_state) = map(serialization_utils.extract_inner_model,
                            (weights, state))
        inner_model(np.array([[0]]), weights=inner_weights, state=inner_state)
Example #2
 def _make_model_and_session():
   m = transformer.TransformerLM(
       vocab_size, d_model=4, d_ff=4, n_layers=1, n_heads=2, dropout=0.)
   ts = training.Loop(m, [task], eval_tasks=[eval_task],
                      eval_at=lambda step_n: step_n % 2 == 0,  # eval every 2nd step
                      output_dir=tmp_dir)
   return m, ts
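
A minimal sketch of driving this helper, assuming task, eval_task and tmp_dir are defined by the enclosing test as above:

# Sketch: run a few steps of the Loop returned by the helper above.
model, loop = _make_model_and_session()
loop.run(n_steps=4)  # four training steps; eval fires on even step numbers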
Example #3
 def test_transformer_lm_forward_shape(self):
   vocab_size = 16
   model = transformer.TransformerLM(
       vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2)
   x = np.ones((3, 5), dtype=np.int32)  # a batch of 3 sequences of length 5
   _, _ = model.init(shapes.signature(x))
   y = model(x)
   self.assertEqual(y.shape, (3, 5, vocab_size))  # one distribution per position
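
Since the model emits one distribution over the vocabulary per position, a greedy next-token readout is a one-liner; a small follow-on sketch using the y from the test above:

# Sketch: greedy readout of the predicted token at the last position; an
# argmax over the vocab axis works for both logits and log-probabilities.
next_tokens = np.argmax(y[:, -1, :], axis=-1)  # shape (3,), one token id per row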
Example #4
 def test_transformer_lm_forward_shape(self):
   """Run the Transformer LM forward and check output shape."""
   vocab_size = 16
   input_shape = [3, 5]
   model = transformer.TransformerLM(
       vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2)
   # check_shape_agreement initializes the model, runs it on dummy integer
   # inputs of the given shape, and returns the resulting output shape.
   final_shape = tl.check_shape_agreement(
       model, tuple(input_shape), integer_inputs=True)
   self.assertEqual(tuple(input_shape + [vocab_size]), final_shape)
Example #5
 def test_transformer_lm_forward_shape(self):
     """Run the Transformer LM forward and check output shape."""
     vocab_size = 16
     input_signature = ShapeDtype((3, 5), onp.int32)
     model = transformer.TransformerLM(vocab_size,
                                       d_model=32,
                                       d_ff=64,
                                       n_layers=2,
                                       n_heads=2)
     final_shape = tl.check_shape_agreement(model, input_signature)
     self.assertEqual((3, 5, vocab_size), final_shape)
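
The explicit ShapeDtype here spells out the same signature that shapes.signature derives from a concrete batch in Example #3; a one-line check of that correspondence (assuming trax's shapes module is in scope):

# Sketch: the signature derived from a concrete array matches the explicit one.
sig = shapes.signature(onp.ones((3, 5), dtype=onp.int32))
assert sig.shape == (3, 5) and sig.dtype == onp.int32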
Example #6
      maxval = 1
    # Sample uniform dummy values matching each signature's shape and dtype.
    return rng.uniform(
        shape=shape, dtype=sig.dtype, minval=minval, maxval=maxval)
  return math_lib.nested_map(f, input_sig)


def Mod(n):  # pylint: disable=invalid-name
  return layers.Fn("Mod", lambda x: x % n)


# Format:
#   (trax-layer maker, input shapes, input dtype, can handle None batch size?)
_LAYERS = [
    (lambda: layers.Dense(3), tf.TensorShape([4]), onp.float32, True),
    (mlp.MLP, tf.TensorShape([4]), onp.float32, False),
    (lambda: layers.Serial(Mod(8), transformer.TransformerLM(8)),
     tf.TensorShape([4]), onp.int32, False),
]
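
The third _LAYERS entry pairs TransformerLM(8) with Mod because the LM only accepts token ids below its vocab size; Mod(8) folds arbitrary integers into that range. Since Fn layers carry no weights, a sketch can call Mod directly:

# Sketch: Fn layers are weightless, so Mod can be applied without an init step.
mod8 = Mod(8)
print(mod8(onp.array([3, 9, 17])))  # expected: [3 1 1]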


_RNG_UPDATERS = [
    lambda x: x,
    lambda rng: math_lib.random.split(rng, 1)[0],
]
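
Both updaters take a PRNG key to a PRNG key, either unchanged or as the sole result of a one-way split; a sketch, assuming math_lib.random.get_prng as the key constructor (an assumption here, mirroring the Trax test utilities):

# Sketch: each updater maps a valid PRNG key to a valid PRNG key.
rng = math_lib.random.get_prng(0)  # assumed key constructor
for updater in _RNG_UPDATERS:
  rng = updater(rng)  # identity, or the first key of a 1-way split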


# Needs tf.test.TestCase for `assertAllClose` and `get_temp_dir`
class Trax2KerasTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      [{"testcase_name": "_%s_%s_%s_%s_%s_%s" % (  # pylint: disable=g-complex-comprehension