def test_negative_variadic_split_axis(self, axis):
    """Check that ``VariadicSplit.infer`` rejects an invalid ``axis`` value.

    A graph is built with a ``[2, 12, 25, 30]`` input shape, split lengths
    ``[2, 13, 10]`` and the given ``axis`` value. Inference must fail with
    an ``AssertionError`` whose first argument is the exact message
    asserted at the end of the test.

    :param axis: axis value placed into ``split_axis_data``; presumably
        supplied by a parametrizing decorator that is not visible here.
    """
    lengths = int64_array([2, 13, 10])
    graph = build_graph(
        self.nodes, self.edges, {
            'split_input_data': {'shape': int64_array([2, 12, 25, 30])},
            'split_axis_data': {'value': axis},
            'split_lengths_data': {'value': lengths},
            'split_op': {'out_ports_count': 4},
        })
    node = Node(graph, 'split_op')

    # Add output ports for every index from the current number of output
    # edges up to the declared `out_ports_count`.
    for p in range(len(node.out_edges()), node.out_ports_count):
        node.add_output_port(p)

    # A try/except around the call would let the test pass silently when no
    # exception is raised at all; assertRaises fails the test in that case.
    with self.assertRaises(AssertionError) as ctx:
        VariadicSplit.infer(node)

    self.assertEqual(
        ctx.exception.args[0],
        'VariadicSplit `axis` should be scalar or tensor with shape [1], '
        'but it`s not for node split_op')
def test_variadic_split_axis(self, axis):
    """Check the output shapes produced by ``VariadicSplit.infer``.

    A graph is built with a ``[2, 12, 25, 30]`` input shape, split lengths
    ``[2, 13, 10]`` and the given ``axis`` value. After inference the node
    must have exactly three output edges, and output ``i`` must have shape
    ``[2, 12, lengths[i], 30]``.

    :param axis: axis value placed into ``split_axis_data``; presumably
        supplied by a parametrizing decorator that is not visible here.
    """
    lengths = int64_array([2, 13, 10])
    graph = build_graph(
        self.nodes, self.edges, {
            'split_input_data': {'shape': int64_array([2, 12, 25, 30])},
            'split_axis_data': {'value': axis},
            'split_lengths_data': {'value': lengths},
            'split_op': {'out_ports_count': 4},
        })
    node = Node(graph, 'split_op')

    # Add output ports for every index from the current number of output
    # edges up to the declared `out_ports_count`.
    for p in range(len(node.out_edges()), node.out_ports_count):
        node.add_output_port(p)

    VariadicSplit.infer(node)

    out_nodes_count = len(node.out_edges())
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(out_nodes_count, 3)

    for out in range(out_nodes_count):
        actual_shape = node.out_node(out).shape
        expected_shape = int64_array([2, 12, lengths[out], 30])
        self.assertTrue(
            np.all(actual_shape == expected_shape),
            'Unexpected shape for output {}: {} (expected {})'.format(
                out, actual_shape, expected_shape))
def test_variadic_split_value_inference_with_uint32(self):
    """Check ``VariadicSplit.infer`` with unsigned split lengths and a constant input.

    The input has shape ``[2, 12, 25, 30]`` and an all-zeros value, the axis
    is ``2`` and the split lengths are ``[2, 13, 10]`` stored as
    ``np.uint64``. After inference the node must have exactly three output
    edges, and output ``i`` must have shape ``[2, 12, lengths[i], 30]``.

    NOTE(review): the test name says "uint32" while the lengths are created
    with ``dtype=np.uint64`` - confirm which dtype is intended.
    """
    axis = int64_array(2)
    # The sum of a Python int and a NumPy np.uint64 gives float64, but
    # np.split accepts only integers and raises an error for floats.
    # Therefore the np.split arguments need to be cast to integer
    # explicitly; this test was added to cover that case.
    lengths = mo_array([2, 13, 10], dtype=np.uint64)
    input_shape = mo_array([2, 12, 25, 30])
    input_value = np.zeros(input_shape)
    graph = build_graph(
        self.nodes, self.edges, {
            'split_input_data': {'shape': input_shape, 'value': input_value},
            'split_axis_data': {'value': axis},
            'split_lengths_data': {'value': lengths},
            'split_op': {'out_ports_count': 4},
        })
    node = Node(graph, 'split_op')

    # Add output ports for every index from the current number of output
    # edges up to the declared `out_ports_count`.
    for p in range(len(node.out_edges()), node.out_ports_count):
        node.add_output_port(p)

    VariadicSplit.infer(node)

    out_nodes_count = len(node.out_edges())
    # assertEqual reports both values on failure, unlike assertTrue(a == b).
    self.assertEqual(out_nodes_count, 3)

    for out in range(out_nodes_count):
        actual_shape = node.out_node(out).shape
        expected_shape = int64_array([2, 12, lengths[out], 30])
        self.assertTrue(
            np.all(actual_shape == expected_shape),
            'Unexpected shape for output {}: {} (expected {})'.format(
                out, actual_shape, expected_shape))