def test_transformer_lm_forward_shape(self):
  """Check that TransformerLM maps (batch, length) tokens to per-token logits.

  A small 2-layer, 2-head TransformerLM is built and its output shape is
  inferred via `tl.check_shape_agreement` on integer inputs; the result must
  equal the input shape with the vocabulary size appended as a last axis.
  """
  vocab_size = 16
  batch_shape = (3, 5)
  lm = transformer.TransformerLM(
      vocab_size,
      d_feature=32,
      d_feedforward=64,
      n_layers=2,
      n_heads=2)
  observed = tl.check_shape_agreement(lm, batch_shape, integer_inputs=True)
  self.assertEqual((*batch_shape, vocab_size), observed)
def _test_transformer_forward_shape(self, input_vocab_size, output_vocab_size):
  """Shared helper: check the encoder-decoder Transformer's output shape.

  Builds a small `transformer.Transformer` and feeds it two integer inputs of
  shape (3, 5) each. The inferred output shape must be (3, 5, V). V is
  `output_vocab_size`, or `input_vocab_size` when `output_vocab_size` is None.

  Args:
    input_vocab_size: vocabulary size passed as the model's first argument.
    output_vocab_size: vocabulary size passed as the model's second argument;
      may be None, in which case the expected last axis uses
      `input_vocab_size`.
  """
  seq_shape = (3, 5)
  model = transformer.Transformer(
      input_vocab_size,
      output_vocab_size,
      d_feature=32,
      d_feedforward=64,
      n_layers=2,
      n_heads=2)
  observed = tl.check_shape_agreement(
      model, (seq_shape, seq_shape), integer_inputs=True)
  if output_vocab_size is None:
    logits_dim = input_vocab_size
  else:
    logits_dim = output_vocab_size
  self.assertEqual(seq_shape + (logits_dim,), observed)
def test_reformer_lm_forward_shape(self):
  """Run the ReformerLM forward and check output shape.

  The model is given two integer inputs, each of shape (1, 8). The test
  expects two outputs, each equal to the corresponding input shape with
  `vocab_size` appended as a final axis.
  """
  vocab_size = 16
  input_shape = ((1, 8), (1, 8))
  model = reformer.ReformerLM(
      vocab_size,
      d_model=32,
      d_ff=64,
      d_attention_key=16,
      d_attention_value=16,
      n_layers=1,
      n_heads=2,
      max_len=16,
      n_chunks=2,
      n_attention_chunks=1)
  final_shape = tl.check_shape_agreement(
      model, input_shape, integer_inputs=True)
  # Derive the expectation from vocab_size/input_shape instead of hard-coding
  # 16, so changing either constant keeps the assertion in sync.
  expected_shape = tuple(shape + (vocab_size,) for shape in input_shape)
  self.assertEqual(expected_shape, final_shape)
def test_wide_resnet(self):
  """Check WideResnet maps a (3, 32, 32, 3) batch to (3, 10) class scores.

  The model is wrapped in `tl.Serial` before shape inference. The input axis
  layout is presumably (batch, height, width, channels); confirm against the
  resnet module.
  """
  n_classes = 10
  batch = (3, 32, 32, 3)
  net = tl.Serial(resnet.WideResnet(n_blocks=1, n_output_classes=n_classes))
  observed = tl.check_shape_agreement(net, batch)
  self.assertEqual((batch[0], n_classes), observed)
def test_resnet(self):
  """Check Resnet50 maps a (3, 256, 256, 3) batch to (3, 10) class scores.

  A narrow Resnet50 (d_hidden=8) is wrapped in `tl.Serial` before shape
  inference. The input axis layout is presumably (batch, height, width,
  channels); confirm against the resnet module.
  """
  n_classes = 10
  batch = (3, 256, 256, 3)
  net = tl.Serial(resnet.Resnet50(d_hidden=8, n_output_classes=n_classes))
  observed = tl.check_shape_agreement(net, batch)
  self.assertEqual((batch[0], n_classes), observed)
def test_mlp_forward_shape(self):
  """Check the MLP maps a (3, 28, 28, 1) batch to (3, 10) class scores.

  The model has one hidden width of 32 and 10 output classes. Its output shape
  is inferred with `tl.check_shape_agreement` and compared to
  (batch_size, n_output_classes).
  """
  n_classes = 10
  batch = (3, 28, 28, 1)
  net = mlp.MLP(d_hidden=32, n_output_classes=n_classes)
  observed = tl.check_shape_agreement(net, batch)
  self.assertEqual((batch[0], n_classes), observed)