Example #1
0
    def test_negative_variadic_split_axis(self, axis):
        """Check that VariadicSplit.infer rejects an invalid `axis` value.

        A graph is built with a [2, 12, 25, 30] input shape, split lengths
        [2, 13, 10] and the supplied `axis` value. Inference must raise an
        AssertionError whose first argument is the message asserted below.

        `axis` is stored as the value of 'split_axis_data'; where it comes
        from (e.g. a parametrizing decorator) is not visible in this snippet.
        """
        lengths = int64_array([2, 13, 10])
        graph = build_graph(
            self.nodes, self.edges, {
                'split_input_data': {
                    'shape': int64_array([2, 12, 25, 30])
                },
                'split_axis_data': {
                    'value': axis
                },
                'split_lengths_data': {
                    'value': lengths
                },
                'split_op': {
                    'out_ports_count': 4
                },
            })
        node = Node(graph, 'split_op')
        # Add output ports for every index from the current number of
        # outgoing edges up to the declared 'out_ports_count'.
        for p in range(len(node.out_edges()), node.out_ports_count):
            node.add_output_port(p)

        # assertRaises fails the test when no exception is raised at all;
        # a plain try/except would silently pass in that situation.
        with self.assertRaises(AssertionError) as raised:
            VariadicSplit.infer(node)
        self.assertEqual(
            raised.exception.args[0],
            'VariadicSplit `axis` should be scalar or tensor with shape [1], '
            'but it`s not for node split_op')
Example #2
0
    def test_variadic_split_axis(self, axis):
        """Check the output shapes VariadicSplit.infer produces for `axis`.

        The input shape is [2, 12, 25, 30] and the split lengths are
        [2, 13, 10]. After inference the node must have three outgoing edges,
        and output `i` must have shape [2, 12, lengths[i], 30].

        `axis` is stored as the value of 'split_axis_data'; where it comes
        from is not visible in this snippet.
        """
        split_sizes = int64_array([2, 13, 10])
        attr_overrides = {
            'split_input_data': {'shape': int64_array([2, 12, 25, 30])},
            'split_axis_data': {'value': axis},
            'split_lengths_data': {'value': split_sizes},
            'split_op': {'out_ports_count': 4},
        }
        graph = build_graph(self.nodes, self.edges, attr_overrides)
        split_node = Node(graph, 'split_op')

        # Add output ports for every index from the current number of
        # outgoing edges up to the declared 'out_ports_count'.
        first_missing_port = len(split_node.out_edges())
        for port_idx in range(first_missing_port, split_node.out_ports_count):
            split_node.add_output_port(port_idx)

        VariadicSplit.infer(split_node)

        out_edges_count = len(split_node.out_edges())
        self.assertTrue(out_edges_count == 3)
        for idx in range(out_edges_count):
            actual_shape = split_node.out_node(idx).shape
            expected_shape = int64_array([2, 12, split_sizes[idx], 30])
            shapes_match = np.all(actual_shape == expected_shape)
            self.assertTrue(shapes_match)
Example #3
0
    def test_variadic_split_value_inference_with_uint32(self):
        """Run VariadicSplit.infer with unsigned split lengths and an input value.

        Rationale recorded by the original author: the sum of a Python int
        and a NumPy np.uint64 gives float64, while np.split accepts only
        integers and raises an error for floats, so the np.split arguments
        had to be cast to integers explicitly. This test covers that case.

        NOTE: despite "uint32" in the test name, the lengths are created
        with dtype np.uint64. Only the number of outgoing edges and the
        output shapes are asserted here.
        """
        split_axis = int64_array(2)
        split_sizes = mo_array([2, 13, 10], dtype=np.uint64)
        data_shape = mo_array([2, 12, 25, 30])
        data_value = np.zeros(data_shape)

        attr_overrides = {
            'split_input_data': {'shape': data_shape, 'value': data_value},
            'split_axis_data': {'value': split_axis},
            'split_lengths_data': {'value': split_sizes},
            'split_op': {'out_ports_count': 4},
        }
        graph = build_graph(self.nodes, self.edges, attr_overrides)
        split_node = Node(graph, 'split_op')

        # Add output ports for every index from the current number of
        # outgoing edges up to the declared 'out_ports_count'.
        first_missing_port = len(split_node.out_edges())
        for port_idx in range(first_missing_port, split_node.out_ports_count):
            split_node.add_output_port(port_idx)

        VariadicSplit.infer(split_node)

        out_edges_count = len(split_node.out_edges())
        self.assertTrue(out_edges_count == 3)
        for idx in range(out_edges_count):
            actual_shape = split_node.out_node(idx).shape
            expected_shape = int64_array([2, 12, split_sizes[idx], 30])
            shapes_match = np.all(actual_shape == expected_shape)
            self.assertTrue(shapes_match)