def test_override_placeholder_shapes_dict(self):
    """Every node listed in user_shapes gets the user-supplied shape.

    node_1 starts with a concrete shape and node_2 starts with shape None;
    afterwards both must equal the user shape.
    """
    new_shape = np.array([1, 3, 224, 224])
    edges = [('node_1', 'node_2'), ('node_2', 'op_output')]
    attrs = {
        'node_2': {'shape': None, 'op': 'Parameter'},
        'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Parameter'},
    }
    graph = build_graph(nodes_attributes, edges, attrs, nodes_with_edges_only=True)

    # Both entries deliberately share the same array object.
    requested = {name: [{'shape': new_shape}] for name in ('node_1', 'node_2')}
    override_placeholder_shapes(graph, requested)

    for name in ('node_1', 'node_2'):
        self.assertTrue(np.array_equal(new_shape, graph.node[name]['shape']))
def test_override_placeholder_shapes_real_inputs_and_batch_2(self):
    """
    Test case when batch is set, but shapes in user_shapes is None.

    Each placeholder keeps the shape already stored in the graph. The expected
    results show only the first dimension being replaced by ``batch``.
    """
    graph = build_graph(self.nodes, self.edges)
    shapes = {'placeholder_1': [{'shape': None}], 'placeholder_2': [{'shape': None}]}
    batch = 4
    # Seed BOTH placeholders with known shapes so the expectations below do not
    # depend on fixture defaults (placeholder_1 was previously never seeded).
    graph.node['placeholder_1']['shape'] = np.array([1, 2, 3, 4])
    graph.node['placeholder_2']['shape'] = np.array([1, 5, 6, 7])
    override_placeholder_shapes(graph, shapes, batch)
    np.testing.assert_array_equal(graph.node['placeholder_1']['shape'], np.array([4, 2, 3, 4]))
    np.testing.assert_array_equal(graph.node['placeholder_2']['shape'], np.array([4, 5, 6, 7]))
def test_override_placeholder_shapes_batch_is_not_set(self):
    """With no user shapes and no batch, placeholder shapes must stay as built."""
    graph = build_graph(self.nodes, self.edges)
    override_placeholder_shapes(graph, {}, None)
    for name in ('placeholder_1', 'placeholder_2'):
        built_shape = self.nodes[name]['shape']
        self.assertTrue(np.array_equal(built_shape, graph.node[name]['shape']))
def test_override_placeholder_shapes_real_inputs_and_batch(self):
    """User shapes replace the graph shapes; batch then replaces dimension 0.

    Per the expectations below, [1, 2, 3, 4] with batch=4 must become
    [4, 2, 3, 4].
    """
    graph = build_graph(self.nodes, self.edges)
    user_shapes = {
        'placeholder_1': [{'shape': np.array([1, 2, 3, 4])}],
        'placeholder_2': [{'shape': np.array([1, 5, 6, 7])}],
    }
    override_placeholder_shapes(graph, user_shapes, 4)

    expected = {
        'placeholder_1': np.array([4, 2, 3, 4]),
        'placeholder_2': np.array([4, 5, 6, 7]),
    }
    for name, want in expected.items():
        self.assertTrue(np.array_equal(graph.node[name]['shape'], want))
def test_override_placeholder_shapes_user_node_not_in_graph(self):
    """
    Test for case when user_shapes is not None, but it shouldn't rewrite shapes.

    user_shapes names only 'some_node', which is not one of the nodes wired by
    the edges below, so node_1 must keep the shape it was built with.
    """
    # Distinct name on purpose: another method in this class is called
    # ``test_override_placeholder_shapes`` and would otherwise shadow this one.
    graph = build_graph(nodes_attributes,
                        [('node_1', 'node_2'),
                         ('node_2', 'op_output')
                         ],
                        {'node_2': {'shape': None},
                         'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Parameter'}
                         },
                        nodes_with_edges_only=True)
    node_1_shape = np.array([1, 3, 227, 227])
    user_dict = {'some_node': [{'shape': np.zeros(3)}]}
    override_placeholder_shapes(graph, user_dict)
    res_shape = graph.node['node_1']['shape']
    self.assertTrue(np.array_equal(node_1_shape, res_shape))
def test_override_placeholder_shapes(self):
    """A shape supplied for node_1 in user_shapes replaces its built shape."""
    edges = [('node_1', 'node_2'), ('node_2', 'op_output')]
    attrs = {
        'node_2': {'shape': None},
        'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Parameter'},
    }
    graph = build_graph(nodes_attributes, edges, attrs, nodes_with_edges_only=True)

    new_shape = np.array([1, 3, 224, 224])
    override_placeholder_shapes(graph, {'node_1': [{'shape': new_shape}]})
    self.assertTrue(np.array_equal(new_shape, graph.node['node_1']['shape']))
def test_override_placeholder_no_shape(self):
    """With user_shapes=None the call returns None and node_1 keeps its shape."""
    expected_shape = np.array([1, 3, 227, 227])
    graph = build_graph(
        nodes_attributes,
        [('node_1', 'node_2'), ('node_2', 'op_output')],
        {
            'node_2': {'shape': None, 'op': 'Parameter'},
            'node_1': {'shape': np.array([1, 3, 227, 227]), 'op': 'Parameter'},
        },
        nodes_with_edges_only=True,
    )
    self.assertIsNone(override_placeholder_shapes(graph, None))
    self.assertTrue(np.array_equal(expected_shape, graph.node['node_1']['shape']))