def test_reformer_lm_forward_shape(self):
    """Builds a tiny ReformerLM and checks the shapes of its two outputs.

    Both inputs share one (1, 8) int32 signature; each output is expected to
    have shape (1, 8, 16).
    """
    vocab = 16
    token_sig = ShapeDtype((1, 8), np.int32)
    model = reformer.ReformerLM(
        vocab,
        d_model=32,
        d_ff=64,
        d_attention_key=16,
        d_attention_value=16,
        n_layers=1,
        n_heads=2,
        max_len=16,
        n_chunks=2,
        n_attention_chunks=1,
    )
    got = tl.check_shape_agreement(model, (token_sig, token_sig))
    per_output = (1, 8, 16)
    self.assertEqual((per_output, per_output), got)
def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size):
    """Runs a small Transformer forward and checks the first output's shape.

    Args:
      input_vocab_size: first positional argument to `transformer.Transformer`.
      output_vocab_size: second positional argument to `transformer.Transformer`;
        when None, the expected last output dimension falls back to
        `input_vocab_size`.
    """
    batch_and_len = (3, 5)
    model = transformer.Transformer(
        input_vocab_size,
        output_vocab_size,
        d_model=32,
        d_ff=64,
        n_layers=2,
        n_heads=2,
    )
    # Two inputs of identical shape; integer_inputs asks the checker for
    # integer pseudo-data rather than floats.
    final_shape = tl.check_shape_agreement(
        model, (batch_and_len, batch_and_len), integer_inputs=True)
    if output_vocab_size is None:
        last_dim = input_vocab_size
    else:
        last_dim = output_vocab_size
    self.assertEqual(batch_and_len + (last_dim,), final_shape[0])
def test_rnnlm_forward_shape(self):
    """Checks that RNNLM turns a (3, 28) int32 input into a (3, 28, 20) output."""
    vocab = 20
    token_sig = ShapeDtype((3, 28), dtype=math.numpy.int32)
    model = rnn.RNNLM(vocab_size=vocab, d_model=16)
    observed = tl.check_shape_agreement(model, token_sig)
    self.assertEqual((3, 28, vocab), observed)
def test_mlp_forward_shape(self):
    """Checks that MLP maps a (3, 28, 28, 1) input to a (3, 10) output."""
    n_classes = 10
    model = mlp.MLP(d_hidden=32, n_output_classes=n_classes)
    observed = tl.check_shape_agreement(model, ShapeDtype((3, 28, 28, 1)))
    self.assertEqual((3, n_classes), observed)
def test_wide_resnet(self):
    """Run the WideResnet model forward and check output shape."""
    # Describe the input with ShapeDtype, as the other single-input shape
    # tests here do, instead of passing a bare shape tuple.
    input_signature = ShapeDtype((3, 32, 32, 3))
    model = resnet.WideResnet(n_blocks=1, n_output_classes=10)
    final_shape = tl.check_shape_agreement(model, input_signature)
    self.assertEqual((3, 10), final_shape)
def test_resnet(self):
    """Run the Resnet50 model forward and check output shape."""
    # Describe the input with ShapeDtype, as the other single-input shape
    # tests here do, instead of passing a bare shape tuple.
    input_signature = ShapeDtype((3, 256, 256, 3))
    model = resnet.Resnet50(d_hidden=8, n_output_classes=10)
    final_shape = tl.check_shape_agreement(model, input_signature)
    self.assertEqual((3, 10), final_shape)
def test_pure_mlp_forward_shape(self):
    """Checks that PureMLP's output width equals its last hidden dimension.

    A (7, 28, 28, 3) input is expected to come out with shape (7, 8).
    """
    dims = (32, 16, 8)
    model = mlp.PureMLP(hidden_dims=dims)
    observed = tl.check_shape_agreement(model, ShapeDtype((7, 28, 28, 3)))
    self.assertEqual((7, 8), observed)
def test_image_fec(self):
    """Checks that ImageFEC maps a (3, 256, 256, 3) input to a (3, 10) output."""
    n_classes = 10
    model = ImageFEC(d_hidden=8, n_output_classes=n_classes)
    observed = tl.check_shape_agreement(model, ShapeDtype((3, 256, 256, 3)))
    self.assertEqual((3, n_classes), observed)