コード例 #1
0
ファイル: util_test.py プロジェクト: PapaEcureuil/tfjs-to-tf
 def test_get_output_nodes(self):
     """get_output_nodes should return the graph's output nodes, matching
        testutils.get_outputs name-for-name and in the same order
     """
     graph_def = testutils.get_sample_graph_def()
     actual = util.get_output_nodes(graph_def)
     expected = testutils.get_outputs(graph_def)
     # compare all output names in order rather than only the count and
     # the first name; list equality also covers the length check
     self.assertEqual([node.name for node in actual],
                      [node.name for node in expected])
コード例 #2
0
 def test_get_input_node_map_given_valid_graph(self):
     """get_input_node_map should accept a valid graph and key its result
        by node name
     """
     node_map = rewrite.get_input_node_map(
         testutils.get_sample_graph_def())
     self.assertGreater(len(node_map), 1)
     # spot-check a few well-known nodes of the sample model
     for expected_name in ('model/conv1/BiasAdd',
                           'model/flatten/Reshape',
                           'model/output/MatMul'):
         self.assertIn(expected_name, node_map)
コード例 #3
0
 def test_graph_to_function_v2_given_graph_def(self):
     """graph_to_function_v2 should turn a graph_def into a callable"""
     sample_graph = testutils.get_sample_graph_def(
         testutils.SIMPLE_MODEL_FILE_NAME)
     estimate = api.graph_to_function_v2(sample_graph)
     input_value = 20
     model_input = tf.constant([[input_value]], dtype=tf.float32)
     prediction = as_scalar(estimate(model_input))
     # the simple sample model is expected to yield roughly 5 * x
     self.assertAlmostEqual(prediction, input_value * 5, delta=0.1)
コード例 #4
0
 def test_get_input_node_map_given_duplicates(self):
     """get_input_node_map should raise ValueError on duplicate node names"""
     graph_def = testutils.get_sample_graph_def()
     relu = _get_node_by_name(graph_def, 'model/conv3/Relu')
     negated = rewrite.make_op_node('Neg', list(relu.input), name='kate')
     # deliberately reuse the name of the graph's conv3 BiasAdd node
     clash = rewrite.make_op_node('Exp', negated, name='model/conv3/BiasAdd')
     corrupted_graph = rewrite.update_graph_def(
         graph_def, {'model/conv3/Relu': [negated, clash]}, {})
     with self.assertRaises(ValueError):
         rewrite.get_input_node_map(corrupted_graph)
コード例 #5
0
 def test_update_graph_def_given_replaced_nodes(self):
     """update_graph_def should replace nodes mapped to new sub-graph"""
     graph_def = testutils.get_sample_graph_def()
     # replace the conv1 activation with the logistic sigmoid
     # 1/(1+exp(-x)), built as Neg -> Exp -> Add -> Inv
     relu = _get_node_by_name(graph_def, 'model/conv1/Relu')
     neg = rewrite.make_op_node('Neg', list(relu.input), 'model/conv1/Neg')
     exp = rewrite.make_op_node('Exp', neg, 'model/conv1/Exp')
     # NOTE(review): 'one' is presumably not a node of the sample graph;
     # this test only asserts node presence, so the input is never resolved
     add = rewrite.make_op_node('Add', [exp, 'one'], 'model/conv1/Add')
     inv = rewrite.make_op_node('Inv', add, 'model/conv1/Inv')
     replace_nodes = {'model/conv1/Relu': [neg, exp, add, inv]}
     updated_graph = rewrite.update_graph_def(graph_def, replace_nodes, {})
     for old_name, new_nodes in replace_nodes.items():
         # the replaced node must be gone ...
         self.assertIsNone(_get_node_by_name(updated_graph, old_name))
         # ... and every node of its replacement sub-graph must be present
         for node in new_nodes:
             self.assertIsNotNone(
                 _get_node_by_name(updated_graph, node.name))
コード例 #6
0
 def test_update_graph_def_given_removed_nodes(self):
     """update_graph_def should drop nodes whose replacement is an empty
        list or None
     """
     graph_def = testutils.get_sample_graph_def()
     # drop all nodes of the Keras Conv2D layer 'conv3'; both an empty
     # list and None request removal without a replacement
     remove_nodes = {
         'model/conv3/Conv2D': [],
         'model/conv3/BiasAdd': None,
         'model/conv3/Relu': [],
     }
     pruned = rewrite.update_graph_def(graph_def, remove_nodes, {})
     # the pruned graph must differ from the input graph
     self.assertNotEqual(graph_def, pruned)
     for removed_name in remove_nodes:
         self.assertIsNone(_get_node_by_name(pruned, removed_name))
コード例 #7
0
ファイル: util_test.py プロジェクト: taotaoyuhust/tfjs-to-tf
 def test_rename_input_nodes(self):
     """rename_input_nodes should rename input nodes in-place"""
     graph_def = testutils.get_sample_graph_def(
         testutils.get_path_to(testutils.SIMPLE_MODEL_FILE_NAME))
     renamed = util.rename_input_nodes(graph_def, {'x': 'scalar'})
     # the returned graph must equal the argument, which was edited in-place
     self.assertEqual(graph_def, renamed)
     first_input = util.get_input_nodes(renamed)[0]
     self.assertEqual(first_input.name, 'scalar')
     # the renamed model must still produce roughly 5 * input
     model = testutils.graph_to_model(renamed)
     feed = 18
     outputs = model(tf.constant([[feed]], dtype=tf.float32))
     y = outputs[0].numpy()[0]
     self.assertAlmostEqual(y, feed * 5, delta=0.1)
コード例 #8
0
ファイル: util_test.py プロジェクト: taotaoyuhust/tfjs-to-tf
 def test_rename_input_nodes_reject_invalid_args(self):
     """rename_input_nodes should raise ValueError for invalid mappings"""
     graph_def = testutils.get_sample_graph_def(
         testutils.get_path_to(testutils.SIMPLE_MODEL_FILE_NAME))
     invalid_mappings = (
         {'does-not-exist': 'scalar'},  # unknown node
         {'Identity': 'scalar'},        # not an input node
         {'x': 'Identity'},             # target name already exists
         {'x': 'x'},                    # new name equals old name
     )
     for mapping in invalid_mappings:
         with self.assertRaises(ValueError):
             util.rename_input_nodes(graph_def, mapping)
コード例 #9
0
 def test_update_graph_def_given_remapped_input(self):
     """update_graph_def should remap inputs given a non-empty mapping"""
     graph_def = testutils.get_sample_graph_def()
     # basically removes the Keras Conv2D-layer named 'conv3'
     remove_nodes = {
         'model/conv3/Conv2D': [],
         'model/conv3/BiasAdd': None,
         'model/conv3/Relu': []
     }
     # this time we also re-route the inputs properly: nodes that consumed
     # the removed conv3 output should be wired to the conv2 output
     remap_inputs = {'model/conv3/Relu': 'model/conv2/Relu'}
     updated_graph = rewrite.update_graph_def(graph_def, remove_nodes,
                                              remap_inputs)
     self.assertNotEqual(graph_def, updated_graph)
     # nodes must be removed
     for node_name in remove_nodes:
         self.assertIsNone(_get_node_by_name(updated_graph, node_name))
     # no remaining node may reference a removed node
     for node in updated_graph.node:
         for removed_name in remove_nodes:
             self.assertNotIn(removed_name, node.input)
     # former consumers of a remapped node must now use its replacement;
     # graph_def is unchanged (asserted above), so it lists the consumers
     for old_input, new_input in remap_inputs.items():
         consumer_names = [
             n.name for n in graph_def.node
             if old_input in n.input and n.name not in remove_nodes
         ]
         for consumer_name in consumer_names:
             consumer = _get_node_by_name(updated_graph, consumer_name)
             self.assertIsNotNone(consumer)
             self.assertIn(new_input, consumer.input)
コード例 #10
0
ファイル: util_test.py プロジェクト: taotaoyuhust/tfjs-to-tf
 def test_rename_output_nodes_append_identity(self):
     """rename_output_nodes should handle outputs that aren't Identity ops"""
     graph_def = testutils.get_sample_graph_def(
         testutils.get_path_to(testutils.SIMPLE_MODEL_FILE_NAME))
     # delete the model's first Identity node so the graph output becomes
     # a node of some other op type
     identity_positions = [
         position for position, candidate in enumerate(graph_def.node)
         if candidate.op == 'Identity'
     ]
     del graph_def.node[identity_positions[0]]
     original_output = util.get_output_nodes(graph_def)[0].name
     renamed = util.rename_output_nodes(graph_def,
                                        {original_output: 'estimate'})
     # the returned graph must equal the argument, which was edited in-place
     self.assertEqual(graph_def, renamed)
     self.assertEqual(util.get_output_nodes(renamed)[0].name, 'estimate')
     # the renamed model must still produce roughly 5 * input
     model = testutils.graph_to_model(renamed)
     feed = 18
     outputs = model(tf.constant([[feed]], dtype=tf.float32))
     y = outputs[0].numpy()[0]
     self.assertAlmostEqual(y, feed * 5, delta=0.1)
コード例 #11
0
 def test_update_graph_def_given_empty_args(self):
     """With no replacements and no remappings, update_graph_def should
        return a graph equal to its input
     """
     original = testutils.get_sample_graph_def()
     no_replacements, no_remappings = {}, {}
     result = rewrite.update_graph_def(original, no_replacements,
                                       no_remappings)
     self.assertEqual(original, result)
コード例 #12
0
 def test_validate_supported_ops_given_valid_graph(self):
     """validate_supported_ops should accept a valid graph_def"""
     # the test passes as long as validation does not raise
     rewrite.validate_supported_ops(testutils.get_sample_graph_def())
コード例 #13
0
    def test_replace_matching_nodes(self):
        """replace_matching_nodes should copy graphs without matches as-is
        and splice transformed sub-graphs in place of matching nodes
        """
        # case 1: unchanged copy if no matches
        graph_def = testutils.get_sample_graph_def()

        def _is_prelu(node):
            return node.op == 'Prelu'

        def _remove_node(node, input_map, modifiers):
            # parameter names match the other transform below and avoid
            # shadowing the builtin `map`
            return []

        updated_graph_def, modifiers = rewrite.replace_matching_nodes(
            graph_def, predicate=_is_prelu, transform=_remove_node)
        self.assertEqual(modifiers, {})
        self.assertEqual(updated_graph_def, graph_def)

        # case 2: replaces matching nodes and keeps graph valid
        name_of_node_to_replace = 'model/conv2/Relu'
        new_name_of_replaced_node = ''

        def _must_replace(node):
            return node.name == name_of_node_to_replace

        def _convert_to_log_sigmoid(node, input_map, modifiers):
            """Replace Relu with the logistic sigmoid 1/(1+exp(-x)).

            Generated node names use the 'logSigmoid/' prefix.
            """
            nonlocal new_name_of_replaced_node

            def _get_name(suffix):
                return rewrite.generate_name_from(node.name, input_map,
                                                  f'logSigmoid/{suffix}')

            # -x
            neg = rewrite.make_op_node('Neg',
                                       list(node.input),
                                       name=_get_name('Neg'))
            # exp(-x)
            exp = rewrite.make_op_node('Exp', neg, name=_get_name('Exp'))
            # constant tensor holding "1"
            res = rewrite.make_const_node(np.array([1], dtype=np.float32),
                                          name=_get_name('Var/resource'))
            # Identity over the constant, named like a variable holding "1"
            one = rewrite.make_op_node('Identity', res, _get_name('Var'))
            # 1+exp(-x)
            add = rewrite.make_op_node('Add', [one, exp], _get_name('Add'))
            # 1/(1+exp(-x))
            inv = rewrite.make_op_node('Inv', add, _get_name('Inv'))
            new_name_of_replaced_node = inv.name  # remember the output name
            return [neg, exp, res, one, add, inv]

        updated_graph_def, modifiers = rewrite.replace_matching_nodes(
            graph_def,
            predicate=_must_replace,
            transform=_convert_to_log_sigmoid)

        # replaced node must have been removed
        updated_nodes = rewrite.get_input_node_map(updated_graph_def)
        self.assertNotIn(name_of_node_to_replace, updated_nodes)
        # the transform must have run and its output node must be present
        self.assertIn(new_name_of_replaced_node, updated_nodes)
        # replaced node must not be referenced
        for node in updated_nodes.values():
            # nodes with inputs only
            if node.op not in ('Const', 'Placeholder'):
                self.assertNotIn(name_of_node_to_replace, node.input)

        # references to the replaced node must point to the last node
        # of the replacement sub-graph
        original_nodes = rewrite.get_input_node_map(graph_def)
        replaced_references = [
            node.name for node in original_nodes.values()
            if name_of_node_to_replace in node.input
        ]
        # guard against a vacuous pass: the replaced node must have consumers
        self.assertTrue(replaced_references)
        for node_name in replaced_references:
            node = updated_nodes[node_name]
            self.assertIn(new_name_of_replaced_node, node.input)
コード例 #14
0
ファイル: util_test.py プロジェクト: PapaEcureuil/tfjs-to-tf
 def test_get_output_tensors(self):
     """get_output_tensors should list '<name>:0' for every output node"""
     graph_def = testutils.get_sample_graph_def()
     tensor_names = util.get_output_tensors(graph_def)
     output_nodes = testutils.get_outputs(graph_def)
     self.assertEqual(tensor_names,
                      [node.name + ':0' for node in output_nodes])
コード例 #15
0
 def test_graph_def_to_graph_v1(self):
     """graph_def_to_graph_v1 should build a tf.Graph for inference"""
     sample_graph_def = testutils.get_sample_graph_def(
         testutils.SIMPLE_MODEL_FILE_NAME)
     self.assertIsInstance(api.graph_def_to_graph_v1(sample_graph_def),
                           tf.Graph)