def test_out_port_no_data(self):
    """Add a Parameter on an output port without creating a data node.

    ``add_input_op`` is called on ``op_node`` port 1 with ``data=False`` and
    ``is_out_port=True``. After the ``op_node -> future_input`` edge is
    removed, the graph is compared against a reference. In that reference,
    ``input_node`` (a Parameter carrying the new shape) feeds ``future_input``
    directly. The comparison is checked from two different end nodes.
    """
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes_out,
                                   edges_with_attrs=self.edges_out)
    shape = np.array([1, 2, 3, 4])

    param_node = ('input_node', {'kind': 'op', 'op': 'Parameter', 'shape': shape})
    param_edge = ('input_node', 'future_input', {'in': 0, 'out': 0})
    # The reference drops self.edges_out[0]. NOTE(review): presumably that
    # entry is the op_node -> future_input edge removed below; confirm.
    graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes_out,
                                       edges_with_attrs=self.edges_out[1:],
                                       new_nodes_with_attrs=[param_node],
                                       new_edges_with_attrs=[param_edge])

    add_input_op(graph, 'op_node', 1, data=False, shape=shape, is_out_port=True)
    graph.remove_edge('op_node', 'future_input')

    for last in ('another_node', 'future_input'):
        flag, resp = compare_graphs(graph, graph_ref, last_node=last)
        self.assertTrue(flag, resp)
def test_in_port_with_data(self):
    """Add a Parameter plus a data node on an input port (middle stage).

    The graph's stage is set to ``'middle'``. ``add_input_op`` is then called
    on ``op_node`` port 1 with ``data=True``. After the
    ``future_input -> op_node`` edge is removed, the graph must match a
    reference with the chain ``input_node -> input_data -> op_node`` (in-port 1).
    """
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                   edges_with_attrs=self.edges)
    graph.stage = 'middle'
    shape = np.array([1, 2, 3, 4])

    extra_nodes = [
        ('input_node', {'kind': 'op', 'op': 'Parameter', 'shape': shape}),
        ('input_data', {'kind': 'data'}),
    ]
    extra_edges = [
        ('input_node', 'input_data', {'in': 0, 'out': 0}),
        ('input_data', 'op_node', {'in': 1, 'out': 0}),
    ]
    # The reference drops self.edges[0]. NOTE(review): presumably that entry
    # is the future_input -> op_node edge removed below; confirm.
    graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                       edges_with_attrs=self.edges[1:],
                                       new_nodes_with_attrs=extra_nodes,
                                       new_edges_with_attrs=extra_edges)

    add_input_op(graph, 'op_node', 1, data=True, shape=shape)
    graph.remove_edge('future_input', 'op_node')

    flag, resp = compare_graphs(graph, graph_ref, last_node='op_node')
    self.assertTrue(flag, resp)