def test_get_input_node_map_given_duplicates(self):
    """get_input_node_map must reject a graph whose nodes share a name."""
    graph_def = testutils.get_sample_graph_def()
    relu_node = _get_node_by_name(graph_def, 'model/conv3/Relu')
    # 'kate' is a fresh name, but the Exp node deliberately reuses the name
    # 'model/conv3/BiasAdd'; only the Relu is replaced below, so that name is
    # presumably still taken by the sample graph's own BiasAdd node.
    negate = rewrite.make_op_node('Neg', list(relu_node.input), name='kate')
    clash = rewrite.make_op_node('Exp', negate, name='model/conv3/BiasAdd')
    updated_graph = rewrite.update_graph_def(
        graph_def,
        {'model/conv3/Relu': [negate, clash]},
        {},
    )
    with self.assertRaises(ValueError):
        rewrite.get_input_node_map(updated_graph)
def test_update_graph_def_given_replaced_nodes(self):
    """update_graph_def should replace nodes mapped to new sub-graph"""
    graph_def = testutils.get_sample_graph_def()
    # Replace the conv1 ReLU activation with a logistic sigmoid assembled from
    # primitive ops: Neg -> Exp -> Add('one') -> Inv, i.e. 1 / (1 + exp(-x)).
    # NOTE(review): assumes 'one' names a constant-1 node in the sample graph
    # and that 'Inv' is the reciprocal op — confirm against testutils.
    relu = _get_node_by_name(graph_def, 'model/conv1/Relu')
    neg = rewrite.make_op_node('Neg', list(relu.input), name='model/conv1/Neg')
    exp = rewrite.make_op_node('Exp', neg, name='model/conv1/Exp')
    add = rewrite.make_op_node('Add', [exp, 'one'], name='model/conv1/Add')
    inv = rewrite.make_op_node('Inv', add, name='model/conv1/Inv')
    new_nodes = [neg, exp, add, inv]
    replace_nodes = {'model/conv1/Relu': new_nodes}
    updated_graph = rewrite.update_graph_def(graph_def, replace_nodes, {})
    # the replaced node must be gone ...
    for node_name in replace_nodes:
        self.assertIsNone(_get_node_by_name(updated_graph, node_name))
    # ... and every node of the replacement sub-graph must be present
    for node in new_nodes:
        self.assertIsNotNone(_get_node_by_name(updated_graph, node.name))
def test_update_graph_def_given_removed_nodes(self):
    """update_graph_def should drop every node mapped to None or an empty list."""
    graph_def = testutils.get_sample_graph_def()
    # Strip the Keras Conv2D layer 'conv3' (convolution, bias add, activation);
    # both removal markers -- an empty list and None -- are exercised.
    conv3_names = (
        'model/conv3/Conv2D',
        'model/conv3/BiasAdd',
        'model/conv3/Relu',
    )
    removal_markers = ([], None, [])
    remove_nodes = dict(zip(conv3_names, removal_markers))
    updated_graph = rewrite.update_graph_def(graph_def, remove_nodes, {})
    self.assertNotEqual(graph_def, updated_graph)
    # none of the removed names may survive in the rewritten graph
    for removed_name in conv3_names:
        self.assertIsNone(_get_node_by_name(updated_graph, removed_name))
def test_update_graph_def_given_remapped_input(self):
    """update_graph_def should remap inputs given a non-empty mapping"""
    graph_def = testutils.get_sample_graph_def()
    # remove the Keras Conv2D layer named 'conv3' (convolution, bias add,
    # activation); both removal markers -- [] and None -- are used
    remove_nodes = {
        'model/conv3/Conv2D': [],
        'model/conv3/BiasAdd': None,
        'model/conv3/Relu': [],
    }
    # this time we also re-route the inputs properly: anything that consumed
    # the removed conv3 activation should now read the conv2 activation
    remap_inputs = {'model/conv3/Relu': 'model/conv2/Relu'}
    updated_graph = rewrite.update_graph_def(
        graph_def, remove_nodes, remap_inputs)
    self.assertNotEqual(graph_def, updated_graph)
    # nodes must be removed
    for node_name in remove_nodes:
        self.assertIsNone(_get_node_by_name(updated_graph, node_name))
    # no remaining node may still list a removed node as an input
    # NOTE(review): this is an exact-name match; inputs written as 'name:N'
    # or '^name' would slip through -- confirm the sample graph uses bare names.
    for node in updated_graph.node:
        for node_name in remove_nodes:
            self.assertNotIn(node_name, node.input)
def test_update_graph_def_given_empty_args(self):
    """With nothing to replace and nothing to remap, the graph comes back unchanged."""
    original = testutils.get_sample_graph_def()
    no_replacements, no_remaps = {}, {}
    result = rewrite.update_graph_def(original, no_replacements, no_remaps)
    self.assertEqual(original, result)