def test_get_output_nodes(self):
    """get_output_nodes should return the graph's output nodes.

    Compares the complete, ordered list of node names against the reference
    outputs, not just the first entry, so extra, missing or mismatched
    outputs are all detected.
    """
    graph_def = testutils.get_sample_graph_def()
    actual = util.get_output_nodes(graph_def)
    expected = testutils.get_outputs(graph_def)
    # list equality also covers the length check
    self.assertEqual([node.name for node in actual],
                     [node.name for node in expected])
def test_get_input_node_map_given_valid_graph(self):
    """get_input_node_map should accept a well-formed graph_def."""
    node_map = rewrite.get_input_node_map(testutils.get_sample_graph_def())
    self.assertGreater(len(node_map), 1)
    # spot-check a few nodes taken from different layers of the model
    for expected_name in ('model/conv1/BiasAdd',
                          'model/flatten/Reshape',
                          'model/output/MatMul'):
        self.assertIn(expected_name, node_map)
def test_graph_to_function_v2_given_graph_def(self):
    """graph_to_function_v2 should accept a graph_def.

    The function built from the simple sample model is expected to compute
    approximately ``5 * x``; the test feeds 20 and expects ~100.
    """
    graph_def = testutils.get_sample_graph_def(
        testutils.SIMPLE_MODEL_FILE_NAME)
    estimate = api.graph_to_function_v2(graph_def)
    x_ = 20
    x = tf.constant([[x_]], dtype=tf.float32)
    y = as_scalar(estimate(x))
    self.assertAlmostEqual(y, x_ * 5, delta=0.1)
def test_get_input_node_map_given_duplicates(self):
    """get_input_node_map should raise ValueError on duplicate node names."""
    graph_def = testutils.get_sample_graph_def()
    relu_node = _get_node_by_name(graph_def, 'model/conv3/Relu')
    # 'kate' is just an arbitrary unique name for the first new node
    negate = rewrite.make_op_node('Neg', list(relu_node.input), name='kate')
    # reuse the name 'model/conv3/BiasAdd' (presumably already present in the
    # sample graph) to create a name clash
    clash = rewrite.make_op_node('Exp', negate, name='model/conv3/BiasAdd')
    graph_with_clash = rewrite.update_graph_def(
        graph_def, {'model/conv3/Relu': [negate, clash]}, {})
    with self.assertRaises(ValueError):
        rewrite.get_input_node_map(graph_with_clash)
def test_update_graph_def_given_replaced_nodes(self):
    """update_graph_def should replace nodes mapped to new sub-graph"""
    graph_def = testutils.get_sample_graph_def()
    # replace the conv1 activation with the logistic sigmoid 1/(1+exp(-x))
    relu = _get_node_by_name(graph_def, 'model/conv1/Relu')
    neg = rewrite.make_op_node('Neg', list(relu.input), 'model/conv1/Neg')
    exp = rewrite.make_op_node('Exp', neg, 'model/conv1/Exp')
    # NOTE(review): no node named 'one' is created here; presumably fine since
    # this test only checks node presence, not graph validity - confirm
    add = rewrite.make_op_node('Add', [exp, 'one'], 'model/conv1/Add')
    inv = rewrite.make_op_node('Inv', add, 'model/conv1/Inv')
    replacement = [neg, exp, add, inv]
    replace_nodes = {'model/conv1/Relu': replacement}
    updated_graph = rewrite.update_graph_def(graph_def, replace_nodes, {})
    # the replaced node must be gone ...
    for node_name in replace_nodes:
        self.assertIsNone(_get_node_by_name(updated_graph, node_name))
    # ... and every node of the replacement sub-graph must be present
    for node in replacement:
        self.assertIsNotNone(_get_node_by_name(updated_graph, node.name))
def test_update_graph_def_given_removed_nodes(self):
    """update_graph_def should drop nodes whose replacement is [] or None."""
    graph_def = testutils.get_sample_graph_def()
    # both [] and None mean "delete"; together these nodes basically form
    # the Keras Conv2D layer named 'conv3'
    conv3_nodes = {
        'model/conv3/Conv2D': [],
        'model/conv3/BiasAdd': None,
        'model/conv3/Relu': [],
    }
    pruned_graph = rewrite.update_graph_def(graph_def, conv3_nodes, {})
    self.assertNotEqual(graph_def, pruned_graph)
    # every listed node must have disappeared
    for removed_name in conv3_nodes:
        self.assertIsNone(_get_node_by_name(pruned_graph, removed_name))
def test_rename_input_nodes(self):
    """rename_input_nodes should rename input nodes in-place"""
    model_file = testutils.get_path_to(testutils.SIMPLE_MODEL_FILE_NAME)
    graph_def = testutils.get_sample_graph_def(model_file)
    updated = util.rename_input_nodes(graph_def, {'x': 'scalar'})
    # update should be in-place, so the original graph_def reflects it
    self.assertEqual(graph_def, updated)
    # inputs should be renamed
    self.assertEqual(util.get_input_nodes(updated)[0].name, 'scalar')
    # model should still work and compute roughly 5 * x
    model = testutils.graph_to_model(updated)
    s = 18
    result = model(tf.constant([[s]], dtype=tf.float32))
    y = result[0].numpy()[0]
    self.assertAlmostEqual(y, s * 5, delta=0.1)
def test_rename_input_nodes_reject_invalid_args(self):
    """rename_input_nodes should raise ValueError for every invalid mapping."""
    model_file = testutils.get_path_to(testutils.SIMPLE_MODEL_FILE_NAME)
    graph_def = testutils.get_sample_graph_def(model_file)
    invalid_mappings = (
        {'does-not-exist': 'scalar'},  # unknown node
        {'Identity': 'scalar'},        # existing node that is not an input
        {'x': 'Identity'},             # target name already used by a node
        {'x': 'x'},                    # new name must differ from old name
    )
    for mapping in invalid_mappings:
        with self.assertRaises(ValueError):
            util.rename_input_nodes(graph_def, mapping)
def test_update_graph_def_given_remapped_input(self):
    """update_graph_def should remap inputs given a non-empty mapping"""
    graph_def = testutils.get_sample_graph_def()
    # basically removes the Keras Conv2D-layer named 'conv3'
    remove_nodes = {
        'model/conv3/Conv2D': [],
        'model/conv3/BiasAdd': None,
        'model/conv3/Relu': [],
    }
    # this time we also re-route the inputs properly
    remap_inputs = {'model/conv3/Relu': 'model/conv2/Relu'}
    updated_graph = rewrite.update_graph_def(graph_def, remove_nodes,
                                             remap_inputs)
    self.assertNotEqual(graph_def, updated_graph)
    # nodes must be removed
    for node_name in remove_nodes:
        self.assertIsNone(_get_node_by_name(updated_graph, node_name))
    # no remaining node may reference a removed node
    for node in updated_graph.node:
        for key in remove_nodes:
            self.assertNotIn(key, node.input)
    # former (surviving) consumers of a remapped node must now read from
    # its replacement
    for old_name, new_name in remap_inputs.items():
        consumers = [
            node.name for node in graph_def.node
            if old_name in node.input and node.name not in remove_nodes
        ]
        for consumer_name in consumers:
            consumer = _get_node_by_name(updated_graph, consumer_name)
            self.assertIn(new_name, consumer.input)
def test_rename_output_nodes_append_identity(self):
    """rename_output_nodes should work for outputs that aren't Identity"""
    model_file = testutils.get_path_to(testutils.SIMPLE_MODEL_FILE_NAME)
    graph_def = testutils.get_sample_graph_def(model_file)
    # some open-heart surgery on the model: delete the first "Identity" node
    # so the output to rename is not an Identity op
    identity_index = next(
        i for i, node in enumerate(graph_def.node) if node.op == 'Identity')
    del graph_def.node[identity_index]
    output = util.get_output_nodes(graph_def)[0].name
    updated = util.rename_output_nodes(graph_def, {output: 'estimate'})
    # update should be in-place
    self.assertEqual(graph_def, updated)
    # outputs should be renamed
    self.assertEqual(util.get_output_nodes(updated)[0].name, 'estimate')
    # model should still work and compute roughly 5 * x
    model = testutils.graph_to_model(updated)
    s = 18
    result = model(tf.constant([[s]], dtype=tf.float32))
    y = result[0].numpy()[0]
    self.assertAlmostEqual(y, s * 5, delta=0.1)
def test_update_graph_def_given_empty_args(self):
    """update_graph_def with empty dicts should yield a graph equal to input."""
    graph_def = testutils.get_sample_graph_def()
    no_replacements, no_remapping = {}, {}
    self.assertEqual(
        graph_def,
        rewrite.update_graph_def(graph_def, no_replacements, no_remapping))
def test_validate_supported_ops_given_valid_graph(self):
    """validate_supported_ops should not raise for the sample graph_def."""
    # the test passes as long as no exception escapes; the return value is
    # not inspected
    rewrite.validate_supported_ops(testutils.get_sample_graph_def())
def test_replace_matching_nodes(self):
    """replace_matching_nodes should leave unmatched graphs alone and rewire
    references to replaced nodes.

    Case 1: a predicate that matches nothing yields a graph equal to the
    input and an empty modifier dict.
    Case 2: replacing a Relu with a logistic-sigmoid sub-graph removes the
    Relu and points its former consumers at the sub-graph's last node.
    """
    # case 1: unchanged copy if no matches
    graph_def = testutils.get_sample_graph_def()

    def _is_prelu(node):
        return node.op == 'Prelu'

    def _remove_node(node, input_map, modifiers):
        # drop the matched node entirely
        return []

    updated_graph_def, modifiers = rewrite.replace_matching_nodes(
        graph_def, predicate=_is_prelu, transform=_remove_node)
    self.assertEqual(modifiers, {})
    self.assertEqual(updated_graph_def, graph_def)

    # case 2: replaces matching nodes and keeps graph valid
    name_of_node_to_replace = 'model/conv2/Relu'
    # set by the transform to the name of the replacement's output node
    new_name_of_replaced_node = ''

    def _must_replace(node):
        return node.name == name_of_node_to_replace

    def _convert_to_logistic_sigmoid(node, input_map, modifiers):
        """Replace Relu with the logistic sigmoid 1/(1+exp(-x))."""
        nonlocal new_name_of_replaced_node

        def _get_name(suffix):
            return rewrite.generate_name_from(node.name, input_map,
                                              f'logSigmoid/{suffix}')

        # -x
        neg = rewrite.make_op_node('Neg', list(node.input),
                                   name=_get_name('Neg'))
        # exp(-x)
        exp = rewrite.make_op_node('Exp', neg, name=_get_name('Exp'))
        # constant tensor holding "1"
        res = rewrite.make_const_node(np.array([1], dtype=np.float32),
                                      name=_get_name('Var/resource'))
        # Identity node forwarding the constant "1"
        one = rewrite.make_op_node('Identity', res, _get_name('Var'))
        # 1+exp(-x)
        add = rewrite.make_op_node('Add', [one, exp], _get_name('Add'))
        # 1/(1+exp(-x))
        inv = rewrite.make_op_node('Inv', add, _get_name('Inv'))
        new_name_of_replaced_node = inv.name  # remember the output name
        return [neg, exp, res, one, add, inv]

    updated_graph_def, modifiers = rewrite.replace_matching_nodes(
        graph_def, predicate=_must_replace,
        transform=_convert_to_logistic_sigmoid)
    # replaced node must have been removed
    updated_nodes = rewrite.get_input_node_map(updated_graph_def)
    self.assertNotIn(name_of_node_to_replace, updated_nodes)
    # replaced node must not be referenced (only nodes that have inputs)
    for node in updated_nodes.values():
        if node.op not in ('Const', 'Placeholder'):
            self.assertNotIn(name_of_node_to_replace, node.input)
    # references to the replaced node must point to the last node of the
    # replacement sub-graph
    original_nodes = rewrite.get_input_node_map(graph_def)
    replaced_references = [
        node.name for node in original_nodes.values()
        if name_of_node_to_replace in node.input
    ]
    for node_name in replaced_references:
        self.assertIn(new_name_of_replaced_node,
                      updated_nodes[node_name].input)
def test_get_output_tensors(self):
    """get_output_tensors should return '<node name>:0' for every output."""
    graph_def = testutils.get_sample_graph_def()
    actual = util.get_output_tensors(graph_def)
    expected = [f'{node.name}:0' for node in testutils.get_outputs(graph_def)]
    self.assertEqual(actual, expected)
def test_graph_def_to_graph_v1(self):
    """graph_def_to_graph_v1 should build a tf.Graph for inference."""
    simple_graph_def = testutils.get_sample_graph_def(
        testutils.SIMPLE_MODEL_FILE_NAME)
    converted = api.graph_def_to_graph_v1(simple_graph_def)
    self.assertIsInstance(converted, tf.Graph)