def test_tf_split_infer_input_shape_is_None(self):
    """Output shapes stay ``None`` when only the split axis is configured.

    Only ``split_dim`` is given a value here; the ``data_to_split`` shape is
    left as the fixture defines it (presumably ``None`` -- the fixture is
    not visible in this chunk, confirm against ``setUp``).
    """
    node_under_test = Node(self.graph, 'split_node')
    self.graph.node['split_dim']['value'] = np.array(1)

    tf_split_infer(node_under_test)

    # Inference must not have assigned a shape to any output.
    inferred_shapes = [
        output.shape for output in node_under_test.out_nodes().values()
    ]
    for inferred in inferred_shapes:
        self.assertIsNone(inferred)
def test_tf_split_infer_unknown_index(self):
    """Output shapes stay ``None`` when the split axis value is not set.

    The input shape is provided, but ``split_dim`` keeps whatever value the
    fixture gives it (presumably none -- confirm against ``setUp``).
    """
    node_under_test = Node(self.graph, 'split_node')
    self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

    tf_split_infer(node_under_test)

    # Inference must not have assigned a shape to any output.
    inferred_shapes = [
        output.shape for output in node_under_test.out_nodes().values()
    ]
    for inferred in inferred_shapes:
        self.assertIsNone(inferred)
def test_tf_split_infer_wrong_num_split(self):
    """Output shapes stay ``None`` when splitting along axis 0 of size 2.

    NOTE(review): presumably the fixture's ``num_split`` does not evenly
    divide the axis-0 extent (2); the fixture is not visible in this chunk,
    so confirm against ``setUp``.
    """
    node_under_test = Node(self.graph, 'split_node')
    self.graph.node['split_dim']['value'] = np.array(0)
    self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

    tf_split_infer(node_under_test)

    # Inference must not have assigned a shape to any output.
    inferred_shapes = [
        output.shape for output in node_under_test.out_nodes().values()
    ]
    for inferred in inferred_shapes:
        self.assertIsNone(inferred)
def test_tf_split_infer_negative_index(self):
    """A negative split axis yields the expected per-output shape.

    With ``split_dim = -3`` and an input shape of ``[2, 12, 25, 30]`` every
    output is expected to have shape ``[2, 4, 25, 30]``, and the split node
    is expected to report ``input_port == 1`` after inference.
    """
    split_node = Node(self.graph, 'split_node')
    self.graph.node['split_dim']['value'] = np.array(-3)
    self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30])

    tf_split_infer(split_node)

    self.assertEqual(1, split_node.input_port)

    exp_shape = int64_array([2, 4, 25, 30])
    for out_node in split_node.out_nodes().values():
        # np.array_equal compares shapes before values, so a result of the
        # wrong rank or length fails cleanly instead of being broadcast
        # against exp_shape (which np.all(a == b) would silently allow).
        self.assertTrue(
            np.array_equal(exp_shape, out_node.shape),
            'expected shape {}, got {}'.format(exp_shape, out_node.shape),
        )