示例#1
0
    def test_tf_split_infer_input_shape_is_None(self):
        """Outputs stay shapeless when the input data shape is unknown.

        Only the split axis (``split_dim`` = 1) is provided; ``data_to_split``
        keeps whatever shape the fixture gave it (presumably ``None``, per the
        test name — confirm against ``setUp``). After inference every output
        data node is expected to have ``shape`` equal to ``None``.
        """
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(1)

        tf_split_infer(split_node)
        out_nodes = split_node.out_nodes()
        # Guard against a vacuous pass: with no output nodes the loop below
        # would assert nothing at all.
        self.assertGreater(len(out_nodes), 0)
        for out_node in out_nodes.values():
            self.assertIsNone(out_node.shape)
示例#2
0
    def test_tf_split_infer_unknown_index(self):
        """Outputs stay shapeless when the split axis value is unknown.

        The input shape is set to ``[2, 12, 25, 30]``, but ``split_dim`` keeps
        the fixture's value (presumably no ``value``, per the test name —
        confirm against ``setUp``). After inference every output data node is
        expected to have ``shape`` equal to ``None``.
        """
        split_node = Node(self.graph, 'split_node')
        self.graph.node['data_to_split']['shape'] = int64_array(
            [2, 12, 25, 30])

        tf_split_infer(split_node)
        out_nodes = split_node.out_nodes()
        # Guard against a vacuous pass: with no output nodes the loop below
        # would assert nothing at all.
        self.assertGreater(len(out_nodes), 0)
        for out_node in out_nodes.values():
            self.assertIsNone(out_node.shape)
示例#3
0
    def test_tf_split_infer_wrong_num_split(self):
        """Outputs stay shapeless when the axis cannot be split as requested.

        The input ``[2, 12, 25, 30]`` is split along axis 0 (size 2). Per the
        test name, the node's ``num_split`` (set by the fixture, not visible
        here) presumably does not divide 2 evenly — confirm against ``setUp``.
        After inference every output data node is expected to have ``shape``
        equal to ``None``.
        """
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(0)
        self.graph.node['data_to_split']['shape'] = int64_array(
            [2, 12, 25, 30])

        tf_split_infer(split_node)
        out_nodes = split_node.out_nodes()
        # Guard against a vacuous pass: with no output nodes the loop below
        # would assert nothing at all.
        self.assertGreater(len(out_nodes), 0)
        for out_node in out_nodes.values():
            self.assertIsNone(out_node.shape)
示例#4
0
    def test_tf_split_infer_negative_index(self):
        """A negative split axis is normalized before splitting.

        The input is ``[2, 12, 25, 30]`` (rank 4) and ``split_dim`` is -3, so
        the effective axis is ``-3 + 4 = 1``. Each output is expected to be
        ``[2, 4, 25, 30]``, i.e. dimension 1 (size 12) split into equal parts
        of 4. That implies the fixture's ``num_split`` is 3 (not visible here;
        confirm against ``setUp``). The node's ``input_port`` attribute is
        also expected to be 1 after inference.
        """
        split_node = Node(self.graph, 'split_node')
        self.graph.node['split_dim']['value'] = np.array(-3)
        self.graph.node['data_to_split']['shape'] = int64_array(
            [2, 12, 25, 30])

        tf_split_infer(split_node)
        exp_shape = int64_array([2, 4, 25, 30])
        out_nodes = split_node.out_nodes()
        # Guard against a vacuous pass: with no output nodes the loop below
        # would assert nothing at all.
        self.assertGreater(len(out_nodes), 0)
        for out_node in out_nodes.values():
            # assert_array_equal fails on non-scalar shape mismatches that
            # `==` would silently broadcast (e.g. a (1, 4) array) and reports
            # the differing elements, unlike assertTrue(np.all(a == b)).
            np.testing.assert_array_equal(out_node.shape, exp_shape)
        self.assertEqual(1, split_node.input_port)