def test_get_input_node_map_given_duplicates(self):
     """Duplicate node names must make get_input_node_map raise ValueError"""
     graph_def = testutils.get_sample_graph_def()
     original_relu = _get_node_by_name(graph_def, 'model/conv3/Relu')
     # the Relu is swapped for two nodes; the second one deliberately reuses
     # the name of the conv3 BiasAdd node so that the name occurs twice
     negate = rewrite.make_op_node('Neg', list(original_relu.input),
                                   name='kate')
     duplicate = rewrite.make_op_node('Exp', negate,
                                      name='model/conv3/BiasAdd')
     graph_with_duplicate = rewrite.update_graph_def(
         graph_def, {'model/conv3/Relu': [negate, duplicate]}, {})
     with self.assertRaises(ValueError):
         rewrite.get_input_node_map(graph_with_duplicate)
 def test_update_graph_def_given_replaced_nodes(self):
     """update_graph_def should replace nodes mapped to new sub-graph

        Every key of the replacement mapping must disappear from the result,
        and every node of every replacement sub-graph must be present.
     """
     graph_def = testutils.get_sample_graph_def()
     # replace the conv1 activation with a logistic sigmoid assembled from
     # primitive ops: 1 / (1 + exp(-x))
     relu = _get_node_by_name(graph_def, 'model/conv1/Relu')
     neg = rewrite.make_op_node('Neg', list(relu.input), 'model/conv1/Neg')
     exp = rewrite.make_op_node('Exp', neg, 'model/conv1/Exp')
     # NOTE(review): 'one' is passed as a plain input name; this test does
     # not verify that a node with that name exists in the sample graph
     add = rewrite.make_op_node('Add', [exp, 'one'], 'model/conv1/Add')
     inv = rewrite.make_op_node('Inv', add, 'model/conv1/Inv')
     replace_nodes = {'model/conv1/Relu': [neg, exp, add, inv]}
     updated_graph = rewrite.update_graph_def(graph_def, replace_nodes, {})
     # the replaced nodes must be gone...
     for node_name in replace_nodes:
         self.assertIsNone(_get_node_by_name(updated_graph, node_name))
     # ...and all nodes of every replacement sub-graph must be present
     # (not just those of the first mapping entry)
     for new_nodes in replace_nodes.values():
         for node in new_nodes:
             self.assertIsNotNone(_get_node_by_name(updated_graph, node.name))
 def test_update_graph_def_given_removed_nodes(self):
     """update_graph_def should drop every node whose replacement is an
        empty list or None
     """
     graph_def = testutils.get_sample_graph_def()
     # mapping all three conv3 ops to "nothing" strips the Keras
     # Conv2D-layer named 'conv3' from the graph
     remove_nodes = {'model/conv3/Conv2D': [],
                     'model/conv3/BiasAdd': None,
                     'model/conv3/Relu': []}
     pruned_graph = rewrite.update_graph_def(graph_def, remove_nodes, {})
     self.assertNotEqual(graph_def, pruned_graph)
     # none of the removed nodes may survive the update
     for removed_name in remove_nodes:
         self.assertIsNone(_get_node_by_name(pruned_graph, removed_name))
 def test_update_graph_def_given_remapped_input(self):
     """update_graph_def should remap inputs given a non-empty mapping

        Removed nodes must vanish, no node may still reference a removed
        node, and nodes that consumed a remapped output must afterwards
        consume the mapped-to node instead.
     """
     graph_def = testutils.get_sample_graph_def()
     # basically removes the Keras Conv2D-layer named 'conv3'
     remove_nodes = {
         'model/conv3/Conv2D': [],
         'model/conv3/BiasAdd': None,
         'model/conv3/Relu': []
     }
     # this time we also re-route the inputs properly
     remap_inputs = {'model/conv3/Relu': 'model/conv2/Relu'}
     # remember which surviving nodes read a remapped output before the
     # update, so the re-wiring itself can be verified afterwards
     consumers = {
         node.name: [name for name in node.input if name in remap_inputs]
         for node in graph_def.node
         if node.name not in remove_nodes
         and any(name in remap_inputs for name in node.input)
     }
     updated_graph = rewrite.update_graph_def(graph_def, remove_nodes,
                                              remap_inputs)
     self.assertNotEqual(graph_def, updated_graph)
     # nodes must be removed
     for node_name in remove_nodes:
         self.assertIsNone(_get_node_by_name(updated_graph, node_name))
     # no input may reference a removed node any longer...
     for node in updated_graph.node:
         for key in remove_nodes:
             self.assertNotIn(key, node.input)
     # ...and former consumers must now read from the remapped source
     for node_name, old_inputs in consumers.items():
         node = _get_node_by_name(updated_graph, node_name)
         self.assertIsNotNone(node)
         for old_input in old_inputs:
             self.assertIn(remap_inputs[old_input], node.input)
 def test_update_graph_def_given_empty_args(self):
     """update_graph_def should reproduce its input unchanged when there is
        nothing to replace and nothing to remap
     """
     original = testutils.get_sample_graph_def()
     no_replacements, no_remaps = {}, {}
     result = rewrite.update_graph_def(original, no_replacements, no_remaps)
     self.assertEqual(original, result)