Пример #1
0
def _build_with_sig_def(sess, graph, sig_def):
    """Freeze ``graph`` into a ``TFInputGraph`` described by a SignatureDef.

    Every entry of ``sig_def.inputs`` / ``sig_def.outputs`` is resolved to the
    ``name`` of its tensor info. The input names are checked against
    ``graph``; the output tensors are handed, together with ``sess``, to
    ``tfx.strip_and_freeze_until`` to produce the frozen ``graph_def``.

    :param sess: session made the default while freezing.
    :param graph: graph made the default and searched for the tensors.
    :param sig_def: SignatureDef whose ``inputs``/``outputs`` map signature
        keys to tensor infos exposing a ``name``; must be truthy.
    :return: ``TFInputGraph`` holding the frozen ``graph_def`` plus the
        signature-key -> tensor-name mappings for inputs and outputs.
    :raises AssertionError: if ``sig_def`` is falsy (despite the message,
        the check is truthiness, not ``is None``) or if ``tfx.get_op``
        returns a falsy value for an input name.
    """
    assert sig_def, 'signature_def must not be None'

    with sess.as_default(), graph.as_default():
        # (signature key, tensor name) pairs, kept in the SignatureDef's own
        # iteration order so the fetch list below follows that order too.
        input_pairs = [(key, info.name) for key, info in sig_def.inputs.items()]
        output_pairs = [(key, info.name) for key, info in sig_def.outputs.items()]
        feed_mapping = dict(input_pairs)
        fetch_mapping = dict(output_pairs)

        for _, name in input_pairs:
            assert tfx.get_op(name, graph), \
                'requested tensor {} but found none in graph {}'.format(name, graph)

        output_tensors = [tfx.get_tensor(name, graph) for _, name in output_pairs]
        frozen = tfx.strip_and_freeze_until(output_tensors, graph, sess)

    return TFInputGraph(graph_def=frozen,
                        input_tensor_name_from_signature=feed_mapping,
                        output_tensor_name_from_signature=fetch_mapping)
Пример #2
0
def _build_with_sig_def(sess, graph, sig_def):
    """Build a frozen ``TFInputGraph`` from ``graph`` using a SignatureDef.

    :param sess: session used (as default) when freezing variables.
    :param graph: graph used (as default) and searched for tensors.
    :param sig_def: SignatureDef with ``inputs``/``outputs`` maps from
        signature key to tensor info (each carrying a ``name``); must be truthy.
    :return: ``TFInputGraph`` with the stripped-and-frozen ``graph_def`` and
        both signature-key -> tensor-name mappings.
    :raises AssertionError: when ``sig_def`` is falsy or an input name is not
        found by ``tfx.get_op``.
    """
    assert sig_def, 'signature_def must not be None'

    def _key_to_name(tensor_infos):
        # Returns ({signature key: tensor name}, [tensor names in map order]).
        mapping, names = {}, []
        for key, info in tensor_infos.items():
            resolved = info.name
            mapping[key] = resolved
            names.append(resolved)
        return mapping, names

    with sess.as_default(), graph.as_default():
        feed_mapping, feed_names = _key_to_name(sig_def.inputs)
        fetch_mapping, fetch_names = _key_to_name(sig_def.outputs)

        # The message is only formatted if the assertion actually fails.
        missing_msg = 'requested tensor {} but found none in graph {}'
        for name in feed_names:
            assert tfx.get_op(name, graph), missing_msg.format(name, graph)

        fetch_tensors = [tfx.get_tensor(name, graph) for name in fetch_names]
        frozen_graph_def = tfx.strip_and_freeze_until(fetch_tensors, graph, sess)

    return TFInputGraph(
        graph_def=frozen_graph_def,
        input_tensor_name_from_signature=feed_mapping,
        output_tensor_name_from_signature=fetch_mapping,
    )
Пример #3
0
def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
    """Freeze ``graph`` into a ``TFInputGraph`` from explicit feed/fetch names.

    :param sess: session made the default while freezing.
    :param graph: graph made the default and searched for the tensors.
    :param feed_names: iterable of input tensor names; must not be ``None``
        (an empty iterable is accepted).
    :param fetch_names: iterable of output tensor names; must not be ``None``.
    :return: ``TFInputGraph`` with the frozen ``graph_def`` and no
        signature mappings (both set to ``None``).
    :raises AssertionError: if either name collection is ``None`` or
        ``tfx.get_op`` returns a falsy value for a feed name.
    """
    assert feed_names is not None, "must provide feed_names"
    assert fetch_names is not None, "must provide fetch names"

    with sess.as_default(), graph.as_default():
        # Validate every feed before freezing anything.
        for feed in feed_names:
            assert tfx.get_op(feed, graph), \
                'requested tensor {} but found none in graph {}'.format(feed, graph)
        output_tensors = [tfx.get_tensor(fetch, graph) for fetch in fetch_names]
        frozen_graph_def = tfx.strip_and_freeze_until(output_tensors, graph, sess)

    # No SignatureDef is involved, so there are no key -> name mappings.
    return TFInputGraph(graph_def=frozen_graph_def,
                        input_tensor_name_from_signature=None,
                        output_tensor_name_from_signature=None)
Пример #4
0
def _gen_invalid_tensor_or_op_with_graph_pairing():
    """Yield ``TestCase``s that look up a graph element in an unrelated graph.

    A constant is created in the current default graph, while every lookup
    is made against a freshly created ``tf.Graph()``. Each ``TestCase.data``
    is a zero-argument callable wrapping one ``tfx.get_op`` /
    ``tfx.get_tensor`` call; presumably the consumer invokes it and expects
    it to fail (not visible here — confirm against the caller).
    """
    const_tensor = tf.constant(1427.08, name='someConstOp')
    unrelated_graph = tf.Graph()
    const_op_name = const_tensor.op.name
    suffix = ' with wrong graph'

    def _case(call, label):
        # Wrap a deferred lookup together with its human-readable description.
        return TestCase(data=call, description='test ' + label + suffix)

    yield _case(lambda: tfx.get_op(const_tensor, unrelated_graph),
                'get_op from tensor')
    yield _case(lambda: tfx.get_tensor(const_tensor, unrelated_graph),
                'get_tensor from tensor')
    yield _case(lambda: tfx.get_op(const_tensor.name, unrelated_graph),
                'get_op from tensor name')
    yield _case(lambda: tfx.get_tensor(const_tensor.name, unrelated_graph),
                'get_tensor from tensor name')
    yield _case(lambda: tfx.get_op(const_tensor.op, unrelated_graph),
                'get_op from op')
    yield _case(lambda: tfx.get_tensor(const_tensor.op, unrelated_graph),
                'get_tensor from op')
    yield _case(lambda: tfx.get_op(const_op_name, unrelated_graph),
                'get_op from op name')
    yield _case(lambda: tfx.get_tensor(const_op_name, unrelated_graph),
                'get_tensor from op name')
Пример #5
0
def _gen_invalid_tensor_or_op_with_graph_pairing():
    """Generate mismatched (element, graph) lookups as ``TestCase`` objects.

    The element (tensor, tensor name, op, op name of one constant) belongs to
    the default graph, but each lookup targets a brand-new ``tf.Graph()``.
    ``TestCase.data`` holds an un-invoked callable performing the lookup.
    """
    tensor = tf.constant(1427.08, name='someConstOp')
    foreign_graph = tf.Graph()
    name_of_op = tensor.op.name
    wrong_graph = ' with wrong graph'

    # (deferred lookup, description prefix). The lambdas close over the
    # locals above, which never change, so late binding is not an issue.
    cases = [
        (lambda: tfx.get_op(tensor, foreign_graph),
         'test get_op from tensor'),
        (lambda: tfx.get_tensor(tensor, foreign_graph),
         'test get_tensor from tensor'),
        (lambda: tfx.get_op(tensor.name, foreign_graph),
         'test get_op from tensor name'),
        (lambda: tfx.get_tensor(tensor.name, foreign_graph),
         'test get_tensor from tensor name'),
        (lambda: tfx.get_op(tensor.op, foreign_graph),
         'test get_op from op'),
        (lambda: tfx.get_tensor(tensor.op, foreign_graph),
         'test get_tensor from op'),
        (lambda: tfx.get_op(name_of_op, foreign_graph),
         'test get_op from op name'),
        (lambda: tfx.get_tensor(name_of_op, foreign_graph),
         'test get_tensor from op name'),
    ]
    for thunk, prefix in cases:
        yield TestCase(data=thunk, description=prefix + wrong_graph)
Пример #6
0
def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
    """Return a frozen ``TFInputGraph`` built from plain feed/fetch names.

    :param sess: default session during freezing.
    :param graph: default graph during freezing; tensors are looked up here.
    :param feed_names: input tensor names (may be empty, not ``None``).
    :param fetch_names: output tensor names (may be empty, not ``None``).
    :return: ``TFInputGraph`` whose signature mappings are both ``None``.
    :raises AssertionError: on ``None`` name collections, or when a feed
        name's ``tfx.get_op`` lookup is falsy.
    """
    assert feed_names is not None, "must provide feed_names"
    assert fetch_names is not None, "must provide fetch names"

    with sess.as_default(), graph.as_default():
        not_found = 'requested tensor {} but found none in graph {}'
        for input_name in feed_names:
            # The lookup lives inside the assert, matching the original:
            # under ``python -O`` this validation is skipped entirely.
            assert tfx.get_op(input_name, graph), not_found.format(input_name, graph)

        wanted = [tfx.get_tensor(output_name, graph) for output_name in fetch_names]
        graph_def = tfx.strip_and_freeze_until(wanted, graph, sess)

    return TFInputGraph(
        graph_def=graph_def,
        input_tensor_name_from_signature=None,
        output_tensor_name_from_signature=None,
    )
Пример #7
0
    def test_get_graph_elements(self):
        """Look up graph elements both by name and via other graph elements."""

        with IsolatedSession() as issn:
            x_in = tf.placeholder(tf.double, shape=[], name="x")
            z_out = tf.add(x_in, 3, name='z')

            # NOTE(review): this test passes the graph as the FIRST argument to
            # tfx helpers, unlike `tfx.get_tensor(name, graph)` calls elsewhere;
            # confirm which signature the tfx version under test expects.
            graph = issn.graph

            # Tensor lookups hand back the very same tensor objects.
            for tensor in (z_out, x_in):
                self.assertEqual(tfx.get_tensor(graph, tensor), tensor)
            self.assertEqual(graph.get_tensor_by_name("x:0"), tfx.get_tensor(graph, x_in))
            self.assertEqual("x:0", tfx.tensor_name(graph, x_in))

            # Op lookups / op names.
            self.assertEqual(graph.get_operation_by_name("x"), tfx.get_op(graph, x_in))
            self.assertEqual("x", tfx.op_name(graph, x_in))
            self.assertEqual("z", tfx.op_name(graph, z_out))

            # Tensor names use the "<op>:<output index>" form.
            for tensor, expected in ((z_out, "z:0"), (x_in, "x:0")):
                self.assertEqual(tfx.tensor_name(graph, tensor), expected)
Пример #8
0
def _gen_valid_tensor_op_input_combos():
    """Yield ``TestCase``s pairing an expected value with a ``tfx`` lookup result.

    One constant (op ``someConstOp``, tensor ``someConstOp:0``) is created.
    Every supported way of referring to it (tensor, tensor name, op, op name)
    is passed to ``tfx.op_name`` and ``tfx.tensor_name`` (with and without a
    graph) and to ``tfx.get_tensor`` / ``tfx.get_op`` (with the graph).
    Each ``TestCase.data`` is an ``(expected, actual)`` tuple; ``actual`` is
    computed when that case is yielded.
    """
    op_name = 'someConstOp'
    tnsr_name = '{}:0'.format(op_name)
    tnsr = tf.constant(1427.08, name=op_name)
    graph = tnsr.graph

    # Test for op_name
    yield TestCase(data=(op_name, tfx.op_name(tnsr)),
                   description='get op name from tensor (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr, graph)),
                   description='get op name from tensor (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr_name)),
                   description='get op name from tensor name (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr_name, graph)),
                   description='get op name from tensor name (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr.op)),
                   description='get op name from op (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr.op, graph)),
                   description='get op name from op (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(op_name)),
                   description='get op name from op name (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(op_name, graph)),
                   description='get op name from op name (with graph)')

    # Test for tensor_name
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr)),
                   description='get tensor name from tensor (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr, graph)),
                   description='get tensor name from tensor (with graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)),
                   description='get tensor name from tensor name (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name, graph)),
                   description='get tensor name from tensor name (with graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op)),
                   description='get tensor name from op (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op, graph)),
                   description='get tensor name from op (with graph)')
    # The "from op name" cases must pass the bare op name; passing the tensor
    # name here would only duplicate the "from tensor name" cases above and
    # leave op-name -> tensor-name resolution untested.
    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name)),
                   description='get tensor name from op name (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name, graph)),
                   description='get tensor name from op name (with graph)')

    # Test for get_tensor
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr, graph)),
                   description='get tensor from tensor')
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr_name, graph)),
                   description='get tensor from tensor name')
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr.op, graph)),
                   description='get tensor from op')
    yield TestCase(data=(tnsr, tfx.get_tensor(op_name, graph)),
                   description='get tensor from op name')

    # Test for get_op
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)),
                   description='get op from tensor')
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr_name, graph)),
                   description='get op from tensor name')
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr.op, graph)),
                   description='get op from op')
    yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)),
                   description='get op from op name')
Пример #9
0
 def test_invalid_op_inputs_with_wrong_types(self, data, description):
     """``tfx.get_op`` must raise ``TypeError`` for inputs of an unsupported type.

     :param data: the ill-typed value passed to ``tfx.get_op``.
     :param description: label attached to the assertion failure message.
     """
     with self.assertRaises(TypeError, msg=description):
         empty_graph = tf.Graph()
         tfx.get_op(data, empty_graph)
Пример #10
0
def _gen_valid_tensor_op_input_combos():
    """Yield ``(expected, actual)`` ``TestCase``s for every valid lookup form.

    A single constant (op ``someConstOp``, tensor ``someConstOp:0``) is
    referred to as a tensor, a tensor name, an op and an op name. Each form
    is fed to ``tfx.op_name`` / ``tfx.tensor_name`` (without and with the
    graph) and to ``tfx.get_tensor`` / ``tfx.get_op`` (with the graph).
    Cases are produced lazily: each ``actual`` value is computed only when
    its case is yielded.
    """
    op_name = 'someConstOp'
    tnsr_name = '{}:0'.format(op_name)
    tnsr = tf.constant(1427.08, name=op_name)
    graph = tnsr.graph

    # Every way of referring to the same element, with its description label.
    # Iterating this table (rather than hand-writing each case) guarantees
    # the "op name" cases really pass the op name.
    references = ((tnsr, 'tensor'),
                  (tnsr_name, 'tensor name'),
                  (tnsr.op, 'op'),
                  (op_name, 'op name'))

    # Test for op_name
    for ref, label in references:
        yield TestCase(data=(op_name, tfx.op_name(ref)),
                       description='get op name from {} (no graph)'.format(label))
        yield TestCase(data=(op_name, tfx.op_name(ref, graph)),
                       description='get op name from {} (with graph)'.format(label))

    # Test for tensor_name
    for ref, label in references:
        yield TestCase(data=(tnsr_name, tfx.tensor_name(ref)),
                       description='get tensor name from {} (no graph)'.format(label))
        yield TestCase(data=(tnsr_name, tfx.tensor_name(ref, graph)),
                       description='get tensor name from {} (with graph)'.format(label))

    # Test for get_tensor
    for ref, label in references:
        yield TestCase(data=(tnsr, tfx.get_tensor(ref, graph)),
                       description='get tensor from {}'.format(label))

    # Test for get_op
    for ref, label in references:
        yield TestCase(data=(tnsr.op, tfx.get_op(ref, graph)),
                       description='get op from {}'.format(label))
Пример #11
0
 def test_invalid_op_inputs_with_wrong_types(self, data, description):
     """Check that ill-typed inputs to ``tfx.get_op`` raise ``TypeError``.

     :param data: value of an unsupported type handed to ``tfx.get_op``.
     :param description: message shown if no ``TypeError`` is raised.
     """
     expect_type_error = self.assertRaises(TypeError, msg=description)
     with expect_type_error:
         tfx.get_op(data, tf.Graph())