Exemplo n.º 1
0
    def test_split(self, tensor_splits, split_as_arg):
        """Check Functional.Split against a NumPy reference split.

        Args:
            tensor_splits: ``(axis, split_info, splits)`` triple. ``splits``
                is concatenated along ``axis`` to build the operator input;
                ``split_info`` gives the per-output lengths used by the
                reference implementation.
            split_as_arg: if truthy, ``split_info`` is passed to the operator
                as the ``split`` argument; otherwise it is fed to the operator
                as a second input tensor.
        """
        # Output Size: 1 - inf
        axis, split_info, splits = tensor_splits

        # ``split_as_arg`` is honoured so both ways of supplying the split
        # lengths (operator argument vs. second input tensor) get exercised.
        joined = np.concatenate(splits, axis=axis)
        if split_as_arg:
            input_tensors = [joined]
            kwargs = dict(axis=axis, split=split_info, num_output=len(splits))
        else:
            input_tensors = [joined, split_info]
            kwargs = dict(axis=axis, num_output=len(splits))
        result = Functional.Split(*input_tensors, **kwargs)

        def split_ref(data, split=split_info):
            # Slice ``data`` along ``axis`` into consecutive chunks whose
            # lengths are given by ``split``.
            bounds = np.cumsum([0] + list(split))
            return [
                np.array(data.take(np.arange(start, stop), axis=axis))
                for start, stop in zip(bounds[:-1], bounds[1:])
            ]

        # In the tensor-input branch ``split_info`` arrives positionally as
        # ``split``, so both branches slice with the same lengths.
        result_ref = split_ref(*input_tensors)
        np.testing.assert_equal(
            len(result), len(result_ref),
            err_msg='Functional Split output count mismatch'
        )
        for i, ref in enumerate(result_ref):
            np.testing.assert_array_equal(
                result[i], ref, err_msg='Functional Split result mismatch'
            )
Exemplo n.º 2
0
    def test_concat(self, tensor_splits):
        """Compare Functional.Concat with ``np.concatenate``.

        The operator result is unpacked into two outputs: the joined tensor
        and the per-input lengths along ``axis``. Each output is checked
        against a NumPy-computed expectation.

        Args:
            tensor_splits: ``(axis, split_info, splits)`` triple; only
                ``axis`` and ``splits`` are used here.
        """
        # Input Size: 1 -> inf
        axis, _, pieces = tensor_splits
        merged, sizes = Functional.Concat(*pieces, axis=axis)

        expected_merged = np.concatenate(pieces, axis=axis)
        # One entry per input piece: its extent along the concatenation axis.
        expected_sizes = np.array([piece.shape[axis] for piece in pieces])

        comparisons = (
            (merged, expected_merged, 'Functional Concat result mismatch'),
            (sizes, expected_sizes, 'Functional Concat split info mismatch'),
        )
        for actual, expected, message in comparisons:
            np.testing.assert_array_equal(actual, expected, err_msg=message)
Exemplo n.º 3
0
    def test_relu(self, X, engine):
        """Compare Functional.Relu with the Relu operator run directly.

        Args:
            X: input array. It is copied before being adjusted, so the array
                passed in by the caller is left unmodified.
            engine: engine name forwarded to ``core.CreateOperator`` for the
                reference run.
        """
        # Copy first so the in-place nudging below does not modify the
        # array handed to this test.
        X = X.copy()
        # Move every element at least 0.02 away from zero: nonzero values
        # shift outward by 0.02 and exact zeros become +0.02. Presumably this
        # keeps inputs off ReLU's kink at 0 -- TODO confirm intent.
        X += 0.02 * np.sign(X)
        X[X == 0.0] += 0.02
        output = Functional.Relu(X)
        # The Functional result is read both positionally and by blob name;
        # both access paths are checked against the reference.
        Y_l = output[0]
        Y_d = output["output_0"]

        # Reference: run the Relu operator directly under the
        # "tmp_workspace" workspace guard.
        with workspace.WorkspaceGuard("tmp_workspace"):
            op = core.CreateOperator("Relu", ["X"], ["Y"], engine=engine)
            workspace.FeedBlob("X", X)
            workspace.RunOperatorOnce(op)
            Y_ref = workspace.FetchBlob("Y")

        np.testing.assert_array_equal(
            Y_l, Y_ref, err_msg='Functional Relu result mismatch')

        np.testing.assert_array_equal(
            Y_d, Y_ref, err_msg='Functional Relu result mismatch')