def test_tf_space_to_depth_infer_nhwc(self):
    """Check SpaceToDepth shape inference when the graph layout is NHWC.

    The input shape is whatever the shared ``nodes`` fixture assigns to
    the input data node (fixture not shown here); the test only pins the
    shape that inference must write to ``out_data_node``.
    """
    expected = np.array([1, 1024, 576, 256])

    test_graph = build_graph(nodes, edges)
    test_graph.graph['layout'] = 'NHWC'

    # Inference writes the resulting shape onto the output data node.
    SpaceToDepth.infer(Node(test_graph, 'StD'))

    inferred = test_graph.node['out_data_node']['shape']
    self.assertTrue(np.array_equal(expected, inferred))
def test_tf_space_to_depth_infer_nchw(self):
    """Check SpaceToDepth shape inference when the graph layout is NCHW.

    The fixture's input shape is overridden with ``[1, 64, 2048, 1152]``
    and inference is expected to produce ``[1, 256, 1024, 576]`` on
    ``out_data_node``.
    """
    expected = np.array([1, 256, 1024, 576])

    test_graph = build_graph(nodes, edges)
    test_graph.graph['layout'] = 'NCHW'
    # Replace the fixture's input shape with a channels-first one.
    test_graph.node['in_data_node']['shape'] = np.array([1, 64, 2048, 1152])

    SpaceToDepth.infer(Node(test_graph, 'StD'))

    inferred = test_graph.node['out_data_node']['shape']
    self.assertTrue(np.array_equal(expected, inferred))
def test_tf_space_to_depth_infer_nchw_dynamic(self):
    """Check NCHW SpaceToDepth shape inference with a dynamic dimension.

    The input shape ``[1, 64, <dynamic>, 1152]`` is expected to infer to
    ``[1, 256, <dynamic>, 576]``: the dimension at index 2 is dynamic in
    both the input and the expected output.
    """
    expected = shape_array([1, 256, dynamic_dimension_value, 576])

    test_graph = build_graph(nodes, edges)
    test_graph.graph['layout'] = 'NCHW'
    # Replace the fixture's input shape with one whose index-2 dimension
    # is dynamic.
    test_graph.node['in_data_node']['shape'] = shape_array(
        [1, 64, dynamic_dimension_value, 1152])

    SpaceToDepth.infer(Node(test_graph, 'StD'))

    inferred = test_graph.node['out_data_node']['shape']
    # The expected shape contains a dynamic dimension, so the project's
    # tensor comparison helper is used rather than np.array_equal.
    self.assertTrue(strict_compare_tensors(expected, inferred))