def test_extract_inner_model(self):
  """Builds a SerializedModel, then runs the extracted inner model alone.

  Verifies that `serialization_utils.extract_inner_model` pulls weights and
  state out of the serialized wrapper in a form the bare TransformerLM
  accepts for a forward pass.
  """
  vocab_size = 3
  # Trivial LM (no layers) — we only care about weight/state plumbing here.
  language_model = transformer.TransformerLM(
      vocab_size=vocab_size, d_model=2, d_ff=2, n_layers=0)
  observation_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  action_serializer = space_serializer.create(
      gym.spaces.Discrete(2), vocab_size=vocab_size)
  wrapped = serialization_utils.SerializedModel(
      language_model,
      observation_serializer=observation_serializer,
      action_serializer=action_serializer,
      significance_decay=0.9,
  )
  observation_signature = shapes.ShapeDtype((1, 2))
  action_signature = shapes.ShapeDtype((1, 1))
  signature = (
      observation_signature,
      action_signature,
      observation_signature,
      observation_signature,
  )
  (weights, state) = wrapped.init(input_signature=signature)
  inner_weights = serialization_utils.extract_inner_model(weights)
  inner_state = serialization_utils.extract_inner_model(state)
  # Should not raise: the extracted parts fit the inner model's interface.
  language_model(np.array([[0]]), weights=inner_weights, state=inner_state)
def _make_model_and_session():
  """Builds a tiny TransformerLM and a training Loop over it.

  Relies on `vocab_size`, `task`, `eval_task`, and `tmp_dir` from the
  enclosing scope. Returns the (model, loop) pair.
  """
  model = transformer.TransformerLM(
      vocab_size, d_model=4, d_ff=4, n_layers=1, n_heads=2, dropout=0.)
  # Evaluate on every even step so the test exercises the eval path quickly.
  loop = training.Loop(
      model,
      [task],
      eval_tasks=[eval_task],
      eval_at=lambda step: step % 2 == 0,
      output_dir=tmp_dir)
  return model, loop
def test_transformer_lm_forward_shape(self):
  """Initializes a small Transformer LM and checks its output shape."""
  vocab_size = 16
  model = transformer.TransformerLM(
      vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2)
  tokens = np.ones((3, 5), dtype=np.int32)
  model.init(shapes.signature(tokens))
  logits = model(tokens)
  # One distribution over the vocabulary per input position.
  self.assertEqual(logits.shape, (3, 5, vocab_size))
def test_transformer_lm_forward_shape(self):
  """Run the Transformer LM forward and check output shape."""
  vocab_size = 16
  batch, length = 3, 5
  model = transformer.TransformerLM(
      vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2)
  # check_shape_agreement runs the model on dummy integer inputs and
  # returns the shape of the output it produced.
  observed = tl.check_shape_agreement(
      model, (batch, length), integer_inputs=True)
  self.assertEqual((batch, length, vocab_size), observed)
def test_transformer_lm_forward_shape(self):
  """Run the Transformer LM forward and check output shape."""
  vocab_size = 16
  batch, length = 3, 5
  signature = ShapeDtype((batch, length), onp.int32)
  model = transformer.TransformerLM(
      vocab_size, d_model=32, d_ff=64, n_layers=2, n_heads=2)
  observed = tl.check_shape_agreement(model, signature)
  # One distribution over the vocabulary per input position.
  self.assertEqual((batch, length, vocab_size), observed)
maxval = 1 return rng.uniform( shape=shape, dtype=sig.dtype, minval=minval, maxval=maxval) return math_lib.nested_map(f, input_sig) def Mod(n): # pylint: disable=invalid-name return layers.Fn("Mod", lambda x: x % n) # Format: # (trax-layer maker, input shapes, input dtype, can handle None batch size?) _LAYERS = [ (lambda: layers.Dense(3), tf.TensorShape([4]), onp.float32, True), (mlp.MLP, tf.TensorShape([4]), onp.float32, False), (lambda: layers.Serial(Mod(8), transformer.TransformerLM(8)), tf.TensorShape([4]), onp.int32, False), ] _RNG_UPDATERS = [ lambda x: x, lambda rng: math_lib.random.split(rng, 1)[0], ] # Needs tf.test.TestCase for `assertAllClose` and `get_temp_dir` class Trax2KerasTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( [{"testcase_name": "_%s_%s_%s_%s_%s_%s" % ( # pylint: disable=g-complex-comprehension