def _build_and_test_network(input_size, transpose_layers, expected_layers):
            """
            Helper function for testing transpose removal.

            Builds a network that is a linear chain of transpose layers
            ("data" -> t0 -> t1 -> ... -> "out"), runs
            ``remove_redundant_transposes`` on its spec, and checks that
            exactly the expected transposes remain, in order.

            Args:
                input_size: Size of the input network tensor.
                transpose_layers: Array of transpose axes definitions, one per
                    layer, applied in order. Only the last transpose writes
                    the "out" blob.
                expected_layers: Array of indices into transpose_layers indicating
                    which of the transpose layers should be present after the
                    graph pass.

            Raises:
                AssertionError: If the layers before or after the pass do not
                    match expectations.
            """
            input_features = [("data", datatypes.Array(*input_size))]
            output_features = [("out", None)]
            builder = neural_network.NeuralNetworkBuilder(
                input_features, output_features)

            # Chain the transposes: each reads the previous one's output blob,
            # and only the final one writes the network output "out".
            last_output = "data"
            final_idx = len(transpose_layers) - 1
            for idx, axes in enumerate(transpose_layers):
                name = "t{}".format(idx)
                output_name = "out" if idx == final_idx else name + "_out"
                builder.add_transpose(name=name,
                                      axes=axes,
                                      input_name=last_output,
                                      output_name=output_name)
                last_output = output_name

            # The post-pass checks rely on the pass mutating builder.spec in
            # place, so this reference also reflects the layers after the pass.
            spec = builder.spec.neuralNetwork
            # Check the network before the graph pass.
            np.testing.assert_equal(len(spec.layers), len(transpose_layers))
            for layer in spec.layers:
                np.testing.assert_equal("transpose", layer.WhichOneof("layer"))
            # Run the removal pass.
            remove_redundant_transposes(builder.spec)
            # Verify only the expected layers remain, in order, each carrying
            # the axes of the original transpose it corresponds to.
            np.testing.assert_equal(len(spec.layers), len(expected_layers))
            for layer, source_idx in zip(spec.layers, expected_layers):
                np.testing.assert_equal("transpose", layer.WhichOneof("layer"))
                np.testing.assert_array_equal(
                    transpose_layers[source_idx],
                    layer.transpose.axes,
                )
    def _test_builder(self, builder, input_shape, expected_layer_num=None):
        """
        Check that remove_redundant_transposes shrinks a model without
        changing its predictions.

        Args:
            builder: Model builder exposing ``.spec`` with a "data" input and
                an "out" output. Its spec is modified in place by the pass.
            input_shape: Shape of the random array fed as the "data" input.
            expected_layer_num: Exact number of layers expected after the
                pass. If None, only require that the pass removed at least
                one layer.
        """
        data = np.random.rand(*input_shape)

        # Predictions before the pass.
        mlmodel = MLModel(builder.spec)
        output_before = mlmodel.predict({"data": data})["out"]
        num_layers_before = len(builder.spec.neuralNetwork.layers)

        remove_redundant_transposes(builder.spec)

        num_layers_after = len(builder.spec.neuralNetwork.layers)
        if expected_layer_num is None:
            self.assertLess(num_layers_after, num_layers_before)
        else:
            self.assertEqual(num_layers_after, expected_layer_num)

        # Predictions after the pass must match the originals.
        mlmodel = MLModel(builder.spec)
        output_after = mlmodel.predict({"data": data})["out"]

        np.testing.assert_almost_equal(output_before, output_after, decimal=3)
    def test_transpose(self):
        """Check which transposes remove_redundant_transposes keeps or drops."""

        def _build_and_test_network(input_size, transpose_layers, expected_layers):
            """
            Helper function for testing transpose removal.

            Builds a linear chain of transpose layers ("data" -> ... -> "out"),
            runs remove_redundant_transposes on the spec, and checks that
            exactly the expected transposes remain, in order.

            Args:
                input_size: Size of the input network tensor.
                transpose_layers: Array of transpose axes definitions, applied
                    in order. Only the last transpose writes the "out" blob.
                expected_layers: Array of indices into transpose_layers indicating
                    which of the transpose layers should be present after the
                    graph pass.
            """
            input_features = [("data", datatypes.Array(*input_size))]
            output_features = [("out", None)]
            builder = neural_network.NeuralNetworkBuilder(
                input_features, output_features
            )

            last_output = "data"
            final_idx = len(transpose_layers) - 1
            for idx, axes in enumerate(transpose_layers):
                name = "t{}".format(idx)
                output_name = "out" if idx == final_idx else name + "_out"
                builder.add_transpose(
                    name=name, axes=axes, input_name=last_output, output_name=output_name
                )
                last_output = output_name

            # The post-pass checks rely on the pass mutating builder.spec in
            # place, so this reference also reflects the layers after the pass.
            spec = builder.spec.neuralNetwork
            # Check the network before the graph pass.
            np.testing.assert_equal(len(spec.layers), len(transpose_layers))
            for layer in spec.layers:
                np.testing.assert_equal("transpose", layer.WhichOneof("layer"))
            # Run the removal pass.
            remove_redundant_transposes(builder.spec)
            # Verify only the expected layers remain, in order.
            np.testing.assert_equal(len(spec.layers), len(expected_layers))
            for layer, source_idx in zip(spec.layers, expected_layers):
                np.testing.assert_equal("transpose", layer.WhichOneof("layer"))
                np.testing.assert_array_equal(
                    transpose_layers[source_idx],
                    layer.transpose.axes,
                )

        _build_and_test_network(
            input_size=[1, 10, 10],
            # These transposes are not inverses, so neither may be removed.
            transpose_layers=[[2, 0, 1], [2, 0, 1]],
            expected_layers=[0, 1],
        )

        _build_and_test_network(
            input_size=[1, 1, 10, 10, 3],
            # The first two transposes compose to the identity and should be
            # removed; the third is not an identity and must be kept.
            transpose_layers=[[2, 4, 1, 0, 3], [3, 2, 0, 4, 1], [1, 0, 2, 3, 4]],
            expected_layers=[2],
        )

        # A slightly more complicated test case where there are two transposes
        # in topological order, but are actually in parallel in the graph.
        # [0, 2, 1] is its own inverse, so a pass that only looked at layer
        # order would wrongly cancel t1 and t2; both read "data" directly.
        builder = neural_network.NeuralNetworkBuilder(
            [("data", datatypes.Array(2, 4, 8))], [("out", None)]
        )
        builder.add_transpose(
            name="t1", axes=[0, 2, 1], input_name="data", output_name="t1"
        )
        builder.add_transpose(
            name="t2", axes=[0, 2, 1], input_name="data", output_name="t2"
        )
        builder.add_stack(name="stack", input_names=["t1", "t2"], output_name="out")
        spec = builder.spec.neuralNetwork
        # Run the removal pass.
        remove_redundant_transposes(builder.spec)
        # Verify nothing was removed.
        np.testing.assert_equal(len(spec.layers), 3)