def _build_with_sig_def(sess, graph, sig_def):
    """Freeze ``graph`` into a :class:`TFInputGraph` described by a SignatureDef.

    The signature's ``inputs`` map supplies the feeds and its ``outputs`` map
    supplies the fetches; the ``.name`` of each map entry is used as the
    tensor name. The fetch tensors, ``graph`` and ``sess`` are handed to
    ``tfx.strip_and_freeze_until`` (judging by its name it prunes the graph
    to the fetches and freezes it -- confirm in tfx).

    Note: this function name is defined more than once in this module; the
    definition that appears last is the one in effect after import.

    :param sess: session installed as the default session while building.
    :param graph: graph installed as the default graph, in which the
        signature's tensor names are looked up.
    :param sig_def: SignatureDef-like object exposing ``inputs`` and
        ``outputs`` mappings of signature key -> entry with a ``.name``.
    :return: ``TFInputGraph`` holding the frozen graph_def plus the
        signature-key -> tensor-name mappings for inputs and outputs.
    :raises AssertionError: if ``sig_def`` is falsy, or if ``tfx.get_op``
        returns a falsy value for one of the input tensor names.
    """
    assert sig_def, 'signature_def must not be None'

    with sess.as_default(), graph.as_default():
        # Walk each map exactly once so the derived name lists keep the same
        # order as the entries recorded in the mappings.
        input_pairs = [(sig_key, info.name) for sig_key, info in sig_def.inputs.items()]
        output_pairs = [(sig_key, info.name) for sig_key, info in sig_def.outputs.items()]

        feed_mapping = dict(input_pairs)
        fetch_mapping = dict(output_pairs)

        for _, feed_name in input_pairs:
            assert tfx.get_op(feed_name, graph), \
                'requested tensor {} but found none in graph {}'.format(feed_name, graph)

        fetch_tensors = [tfx.get_tensor(fetch_name, graph) for _, fetch_name in output_pairs]
        frozen_graph_def = tfx.strip_and_freeze_until(fetch_tensors, graph, sess)
        return TFInputGraph(graph_def=frozen_graph_def,
                            input_tensor_name_from_signature=feed_mapping,
                            output_tensor_name_from_signature=fetch_mapping)
def _build_with_sig_def(sess, graph, sig_def):
    """Freeze ``graph`` into a :class:`TFInputGraph` described by a SignatureDef.

    The signature's ``inputs`` map supplies the feeds and its ``outputs`` map
    supplies the fetches; the ``.name`` of each map entry is used as the
    tensor name. The fetch tensors, ``graph`` and ``sess`` are handed to
    ``tfx.strip_and_freeze_until`` (judging by its name it prunes the graph
    to the fetches and freezes it -- confirm in tfx).

    Note: this function name is defined more than once in this module; the
    definition that appears last is the one in effect after import.

    :param sess: session installed as the default session while building.
    :param graph: graph installed as the default graph, in which the
        signature's tensor names are looked up.
    :param sig_def: SignatureDef-like object exposing ``inputs`` and
        ``outputs`` mappings of signature key -> entry with a ``.name``.
    :return: ``TFInputGraph`` holding the frozen graph_def plus the
        signature-key -> tensor-name mappings for inputs and outputs.
    :raises AssertionError: if ``sig_def`` is falsy, or if ``tfx.get_op``
        returns a falsy value for one of the input tensor names.
    """
    assert sig_def, 'signature_def must not be None'

    with sess.as_default(), graph.as_default():
        # Walk each map exactly once so the derived name lists keep the same
        # order as the entries recorded in the mappings.
        input_pairs = [(sig_key, info.name) for sig_key, info in sig_def.inputs.items()]
        output_pairs = [(sig_key, info.name) for sig_key, info in sig_def.outputs.items()]

        feed_mapping = dict(input_pairs)
        fetch_mapping = dict(output_pairs)

        for _, feed_name in input_pairs:
            assert tfx.get_op(feed_name, graph), \
                'requested tensor {} but found none in graph {}'.format(feed_name, graph)

        fetch_tensors = [tfx.get_tensor(fetch_name, graph) for _, fetch_name in output_pairs]
        frozen_graph_def = tfx.strip_and_freeze_until(fetch_tensors, graph, sess)
        return TFInputGraph(graph_def=frozen_graph_def,
                            input_tensor_name_from_signature=feed_mapping,
                            output_tensor_name_from_signature=fetch_mapping)
def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
    """Freeze ``graph`` into a :class:`TFInputGraph` from explicit tensor names.

    No signature mappings are recorded: both ``*_tensor_name_from_signature``
    fields of the result are ``None``.

    Note: this function name is defined more than once in this module; the
    definition that appears last is the one in effect after import.

    :param sess: session installed as the default session while building.
    :param graph: graph installed as the default graph, in which the names
        are looked up.
    :param feed_names: iterable of input tensor names; each is checked with
        ``tfx.get_op`` before freezing.
    :param fetch_names: iterable of output tensor names; each is resolved with
        ``tfx.get_tensor`` and passed to ``tfx.strip_and_freeze_until``.
    :return: ``TFInputGraph`` wrapping the frozen graph_def.
    :raises AssertionError: if either name collection is ``None``, or if
        ``tfx.get_op`` returns a falsy value for a feed name.
    """
    assert feed_names is not None, "must provide feed_names"
    assert fetch_names is not None, "must provide fetch names"

    with sess.as_default(), graph.as_default():
        # Sanity-check every feed before doing any freezing work.
        for name in feed_names:
            assert tfx.get_op(name, graph), \
                'requested tensor {} but found none in graph {}'.format(name, graph)

        fetch_tensors = [tfx.get_tensor(name, graph) for name in fetch_names]
        frozen_graph_def = tfx.strip_and_freeze_until(fetch_tensors, graph, sess)
        return TFInputGraph(graph_def=frozen_graph_def,
                            input_tensor_name_from_signature=None,
                            output_tensor_name_from_signature=None)
def _gen_invalid_tensor_or_op_with_graph_pairing():
    """Yield TestCases that pair a graph element with a graph it is not in.

    A constant named ``someConstOp`` is created in the current default graph,
    and a separately constructed ``tf.Graph()`` is used as the (wrong) graph
    argument. Each TestCase's ``data`` is a zero-argument callable, so the
    ``tfx`` lookup is deferred until the consumer invokes it. The description
    names the lookup function and the input form (tensor, tensor name, op,
    op name).
    """
    const_tensor = tf.constant(1427.08, name='someConstOp')
    wrong_graph = tf.Graph()
    const_op_name = const_tensor.op.name
    suffix = ' with wrong graph'

    # (deferred lookup, description stem) pairs, in the order they are yielded.
    cases = [
        (lambda: tfx.get_op(const_tensor, wrong_graph),
         'test get_op from tensor'),
        (lambda: tfx.get_tensor(const_tensor, wrong_graph),
         'test get_tensor from tensor'),
        (lambda: tfx.get_op(const_tensor.name, wrong_graph),
         'test get_op from tensor name'),
        (lambda: tfx.get_tensor(const_tensor.name, wrong_graph),
         'test get_tensor from tensor name'),
        (lambda: tfx.get_op(const_tensor.op, wrong_graph),
         'test get_op from op'),
        (lambda: tfx.get_tensor(const_tensor.op, wrong_graph),
         'test get_tensor from op'),
        (lambda: tfx.get_op(const_op_name, wrong_graph),
         'test get_op from op name'),
        (lambda: tfx.get_tensor(const_op_name, wrong_graph),
         'test get_tensor from op name'),
    ]
    for thunk, stem in cases:
        yield TestCase(data=thunk, description=stem + suffix)
def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
    """Freeze ``graph`` into a :class:`TFInputGraph` from explicit tensor names.

    No signature mappings are recorded: both ``*_tensor_name_from_signature``
    fields of the result are ``None``.

    Note: this function name is defined more than once in this module; the
    definition that appears last is the one in effect after import.

    :param sess: session installed as the default session while building.
    :param graph: graph installed as the default graph, in which the names
        are looked up.
    :param feed_names: iterable of input tensor names; each is checked with
        ``tfx.get_op`` before freezing.
    :param fetch_names: iterable of output tensor names; each is resolved with
        ``tfx.get_tensor`` and passed to ``tfx.strip_and_freeze_until``.
    :return: ``TFInputGraph`` wrapping the frozen graph_def.
    :raises AssertionError: if either name collection is ``None``, or if
        ``tfx.get_op`` returns a falsy value for a feed name.
    """
    assert feed_names is not None, "must provide feed_names"
    assert fetch_names is not None, "must provide fetch names"

    with sess.as_default(), graph.as_default():
        # Sanity-check every feed before doing any freezing work.
        for name in feed_names:
            assert tfx.get_op(name, graph), \
                'requested tensor {} but found none in graph {}'.format(name, graph)

        fetch_tensors = [tfx.get_tensor(name, graph) for name in fetch_names]
        frozen_graph_def = tfx.strip_and_freeze_until(fetch_tensors, graph, sess)
        return TFInputGraph(graph_def=frozen_graph_def,
                            input_tensor_name_from_signature=None,
                            output_tensor_name_from_signature=None)
def test_get_graph_elements(self):
    """ Fetching graph elements by names and other graph elements

    Each ``tfx`` lookup is compared against the tensor/op object itself, the
    graph's own ``get_tensor_by_name`` / ``get_operation_by_name`` results,
    or the expected name string.
    """
    with IsolatedSession() as issn:
        x = tf.placeholder(tf.double, shape=[], name="x")
        z = tf.add(x, 3, name='z')
        g = issn.graph
        # tfx helpers take the graph element (or its name) first and the graph
        # second -- the argument order used by every other tfx call in this
        # module; the reversed (graph, element) order is rejected by tfx.
        self.assertEqual(tfx.get_tensor(z, g), z)
        self.assertEqual(tfx.get_tensor(x, g), x)
        self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(x, g))
        self.assertEqual("x:0", tfx.tensor_name(x, g))
        self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(x, g))
        self.assertEqual("x", tfx.op_name(x, g))
        self.assertEqual("z", tfx.op_name(z, g))
        self.assertEqual(tfx.tensor_name(z, g), "z:0")
        self.assertEqual(tfx.tensor_name(x, g), "x:0")
def _gen_valid_tensor_op_input_combos():
    """Yield TestCases whose ``data`` is an ``(expected, actual)`` pair.

    A float constant is created in the current default graph. Each case runs
    one of ``tfx.op_name``, ``tfx.tensor_name``, ``tfx.get_tensor`` or
    ``tfx.get_op`` on one input form (the tensor, its tensor name, its op, or
    its op name); the name helpers are exercised both with and without an
    explicit graph. The ``actual`` value is computed eagerly when each case is
    generated, not when the consuming test runs.
    """
    tnsr = tf.constant(1427.08, name='someConstOp')
    graph = tnsr.graph
    # Take the names TF actually assigned instead of hard-coding them: if the
    # default graph already holds an op named 'someConstOp', TF uniquifies the
    # new one and hard-coded expectations would point at the wrong element.
    op_name = tnsr.op.name
    tnsr_name = tnsr.name

    # Test for op_name
    yield TestCase(data=(op_name, tfx.op_name(tnsr)),
                   description='get op name from tensor (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr, graph)),
                   description='get op name from tensor (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr_name)),
                   description='get op name from tensor name (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr_name, graph)),
                   description='get op name from tensor name (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr.op)),
                   description='get op name from op (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr.op, graph)),
                   description='get op name from op (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(op_name)),
                   description='get op name from op name (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(op_name, graph)),
                   description='get op name from op name (with graph)')

    # Test for tensor_name
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr)),
                   description='get tensor name from tensor (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr, graph)),
                   description='get tensor name from tensor (with graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)),
                   description='get tensor name from tensor name (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name, graph)),
                   description='get tensor name from tensor name (with graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op)),
                   description='get tensor name from op (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op, graph)),
                   description='get tensor name from op (with graph)')
    # The "from op name" cases must pass the op name itself; passing the
    # tensor name here would merely repeat the "from tensor name" cases.
    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name)),
                   description='get tensor name from op name (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name, graph)),
                   description='get tensor name from op name (with graph)')

    # Test for get_tensor
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr, graph)),
                   description='get tensor from tensor')
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr_name, graph)),
                   description='get tensor from tensor name')
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr.op, graph)),
                   description='get tensor from op')
    yield TestCase(data=(tnsr, tfx.get_tensor(op_name, graph)),
                   description='get tensor from op name')

    # Test for get_op
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)),
                   description='get op from tensor')
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr_name, graph)),
                   description='get op from tensor name')
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr.op, graph)),
                   description='get op from op')
    yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)),
                   description='get op from op name')
def test_invalid_op_inputs_with_wrong_types(self, data, description):
    """ ``tfx.get_op`` must raise TypeError when given an input of the wrong type

    ``data`` is the offending input and ``description`` is reported as the
    failure message if no TypeError is raised.
    """
    with self.assertRaises(TypeError, msg=description):
        target_graph = tf.Graph()
        tfx.get_op(data, target_graph)