def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size):
    """Build a small Transformer, run one forward pass, and verify the shape.

    Args:
      input_vocab_size: Size of the input token vocabulary.
      output_vocab_size: Size of the output vocabulary, or None/falsy to
        share the input vocabulary.
    """
    model = transformer.Transformer(
        input_vocab_size,
        output_vocab_size,
        d_model=32,
        d_ff=64,
        n_encoder_layers=2,
        n_decoder_layers=2,
        n_heads=2)
    # Encoder and decoder each receive a (batch=3, length=5) int32 batch.
    xs = [np.ones((3, 5), dtype=np.int32) for _ in range(2)]
    model.init(shapes.signature(xs))
    y, _ = model(xs)
    # When no separate output vocabulary is given, the input one is used.
    vocab_size = output_vocab_size if output_vocab_size else input_vocab_size
    self.assertEqual(y.shape, (3, 5, vocab_size))
def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size):
    """Run the Transformer forward and check output shape."""
    # Both encoder and decoder inputs are (batch=3, length=5).
    shape = [3, 5]
    input_shape = (tuple(shape), tuple(shape))
    model = transformer.Transformer(
        input_vocab_size,
        output_vocab_size,
        d_model=32,
        d_ff=64,
        n_encoder_layers=2,
        n_decoder_layers=2,
        n_heads=2)
    final_shape = tl.check_shape_agreement(
        model, input_shape, integer_inputs=True)
    # Fall back to the input vocabulary when no output one is specified.
    if output_vocab_size is not None:
        vocab_size = output_vocab_size
    else:
        vocab_size = input_vocab_size
    self.assertEqual(tuple(shape + [vocab_size]), final_shape[0])
def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size):
    """Run the Transformer forward and check output shape."""
    # One (batch=3, length=5) int32 signature, used for encoder and decoder.
    sig = ShapeDtype((3, 5), onp.int32)
    model = transformer.Transformer(
        input_vocab_size,
        output_vocab_size,
        d_model=32,
        d_ff=64,
        n_encoder_layers=2,
        n_decoder_layers=2,
        n_heads=2)
    final_shape = tl.check_shape_agreement(model, (sig, sig))
    # Fall back to the input vocabulary when no output one is specified.
    expected_shape = (3, 5, output_vocab_size or input_vocab_size)
    self.assertEqual(expected_shape, final_shape[0])