コード例 #1
0
    def test_mlp_input_signatures(self):
        """Assigning a signature to the Serial wrapper reaches both sublayers.

        The MLP is expected to see the full (3, 28, 28, 1) batch signature,
        and the Relu that follows it to see the MLP's (3, 10) output.
        """
        mlp_part = mlp.MLP(d_hidden=32, n_output_classes=10)
        activation = tl.Relu()
        pipeline = tl.Serial(mlp_part, activation)

        # Set the signature only on the outer Serial; the assertions below
        # check what each sublayer ended up with.
        pipeline.input_signature = ShapeDtype((3, 28, 28, 1))
        expected_mlp_in = ShapeDtype((3, 28, 28, 1))
        self.assertEqual(mlp_part.input_signature, expected_mlp_in)
        expected_relu_in = ShapeDtype((3, 10))
        self.assertEqual(activation.input_signature, expected_relu_in)
コード例 #2
0
ファイル: mlp_test.py プロジェクト: qsong4/trax
    def test_mlp_input_signatures(self):
        """Check input signatures recorded on an initialized MLP -> Relu stack.

        After initializing on a (3, 28, 28, 1) float32 batch, the MLP should
        report that batch's signature as its input, and the Relu (which
        consumes the MLP output) should report a (3, 10) signature.
        """
        mlp_part = mlp.MLP(d_hidden=32, n_output_classes=10)
        activation = tl.Relu()
        pipeline = tl.Serial(mlp_part, activation)
        batch = np.ones((3, 28, 28, 1), dtype=np.float32)
        batch_sig = shapes.signature(batch)
        pipeline.init(batch_sig)

        # Check what enters mlp_part and what enters the Relu (i.e. what
        # leaves mlp_part). NOTE(review): this relies on a protected trax
        # helper and may be redundant after init() -- confirm.
        pipeline._set_input_signature_recursive(batch_sig)  # pylint: disable=protected-access
        self.assertEqual(mlp_part.input_signature, batch_sig)
        self.assertEqual(activation.input_signature, shapes.ShapeDtype((3, 10)))
コード例 #3
0
ファイル: mlp_test.py プロジェクト: zzszmyf/trax
 def test_mlp_forward_shape(self):
     """Check the MLP maps a (3, 28, 28, 1) input signature to shape (3, 10).

     Uses tl.check_shape_agreement on the model and the input signature.
     """
     in_sig = ShapeDtype((3, 28, 28, 1))
     net = mlp.MLP(d_hidden=32, n_output_classes=10)
     out_shape = tl.check_shape_agreement(net, in_sig)
     self.assertEqual((3, 10), out_shape)
コード例 #4
0
 def test_mlp_forward_shape(self):
     """Run an MLP with layer_widths=(32, 16, 8) on a (7, 28, 28, 3) batch.

     The output is expected to have shape (7, 8): the batch size is kept and
     the trailing dimension matches the last layer width.
     """
     net = mlp.MLP(layer_widths=(32, 16, 8))
     batch = np.ones((7, 28, 28, 3), dtype=np.float32)
     # init returns a pair; only the forward-pass output shape is checked.
     _weights, _state = net.init(shapes.signature(batch))
     out = net(batch)
     self.assertEqual(out.shape, (7, 8))
コード例 #5
0
ファイル: mlp_test.py プロジェクト: qsong4/trax
 def test_mlp_forward_shape(self):
     """Run the MLP on a (3, 28, 28, 1) batch and expect a (3, 10) output."""
     net = mlp.MLP(d_hidden=32, n_output_classes=10)
     batch = np.ones((3, 28, 28, 1), dtype=np.float32)
     # init returns a pair; only the forward-pass output shape is checked.
     _weights, _state = net.init(shapes.signature(batch))
     out = net(batch)
     self.assertEqual(out.shape, (3, 10))