Example #1
def _get_expected_result(gin, local_features):
    """
    Run the graph in the :py:obj:`TFInputGraph` object and compute the expected results.
    :param gin: a :py:obj:`TFInputGraph`
    :param local_features: an iterable of input rows, keyed by column name, to feed to the graph
    :return: the expected results as a NumPy array
    """
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess, graph.as_default():
        # Build test graph and transformers from here
        tf.import_graph_def(gin.graph_def, name='')

        # Build the results
        _results = []
        for row in local_features:
            fetches = [
                tfx.get_tensor(tnsr_name, graph)
                for tnsr_name, _ in _output_mapping.items()
            ]
            feed_dict = {}
            for colname, tnsr_name in _input_mapping.items():
                tnsr = tfx.get_tensor(tnsr_name, graph)
                feed_dict[tnsr] = np.array(row[colname])[np.newaxis, :]

            curr_res = sess.run(fetches, feed_dict=feed_dict)
            _results.append(np.ravel(curr_res))

        expected = np.hstack(_results)

    return expected
def _build_checkpointed_model(session, tmp_dir):
    """
    Writes a model checkpoint in the given directory. The graph is assumed to be generated
    with _build_graph_var.
    """
    ckpt_path_prefix = os.path.join(tmp_dir, 'model_ckpt')
    input_tensor = tfx.get_tensor(_tensor_input_name, session.graph)
    output_tensor = tfx.get_tensor(_tensor_output_name, session.graph)
    w = tfx.get_tensor(_tensor_var_name, session.graph)
    saver = tf.train.Saver(var_list=[w])
    _ = saver.save(session, ckpt_path_prefix, global_step=2702)
    sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)}
    sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)}
    serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=sig_inputs, outputs=sig_outputs)

    # A rather contrived way to add signature def to a meta_graph
    meta_graph_def = tf.train.export_meta_graph()

    # Find the meta_graph file (there should be only one)
    _ckpt_meta_fpaths = glob.glob('{}/*.meta'.format(tmp_dir))
    assert len(_ckpt_meta_fpaths) == 1, \
        'expected only one meta graph, but got {}'.format(','.join(_ckpt_meta_fpaths))
    ckpt_meta_fpath = _ckpt_meta_fpaths[0]

    # Add signature_def to the meta_graph and serialize it
    # This will overwrite the existing meta_graph_def file
    meta_graph_def.signature_def[_serving_sigdef_key].CopyFrom(serving_sigdef)
    with open(ckpt_meta_fpath, mode='wb') as fout:
        fout.write(meta_graph_def.SerializeToString())
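The two helpers above rely on several module-level test fixtures. A minimal sketch of what they might look like follows; the concrete names and values are assumptions for illustration, not taken from the original test module.

# Hypothetical fixtures assumed by the helpers above (values are illustrative).
_tensor_input_name = 'input_tensor'
_tensor_output_name = 'output_tensor'
_tensor_var_name = 'w'
_tensor_input_signature = 'well_known_input_sig'
_tensor_output_signature = 'well_known_output_sig'
_serving_sigdef_key = 'prediction_signature'
_serving_tag = 'serve'

# _input_mapping maps DataFrame column names to input tensor names;
# _output_mapping maps output tensor names to result column names.
_input_mapping = {'inputCol': _tensor_input_name + ':0'}
_output_mapping = {_tensor_output_name + ':0': 'outputCol'}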
Example #5
    def importGraphFunction(self,
                            gfn,
                            input_map=None,
                            prefix="GFN-IMPORT",
                            **gdef_kargs):
        """
        Import a GraphFunction object into the current session.
        The API is similar to :py:meth:`tf.import_graph_def`.

        .. _a link: https://www.tensorflow.org/api_docs/python/tf/import_graph_def

        :param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and outputs
        :param input_map: dict, mapping from input names to existing graph elements
        :param prefix: str, the scope for all the variables in the :py:class:`GraphFunction` elements

                       .. _a link: https://www.tensorflow.org/programmers_guide/variable_scope

        :param gdef_kargs: other keyword arguments passed to TensorFlow's `import_graph_def`
        """
        try:
            del gdef_kargs["return_elements"]
        except KeyError:
            pass
        if input_map is not None:
            assert set(input_map.keys()) <= set(gfn.input_names), \
                "cannot locate provided input elements in the graph"

        input_names = gfn.input_names
        output_names = gfn.output_names
        scope_name = prefix
        if prefix is not None:
            scope_name = prefix.strip()
            if len(scope_name) > 0:
                output_names = [
                    scope_name + '/' + op_name for op_name in gfn.output_names
                ]
                input_names = [
                    scope_name + '/' + op_name for op_name in gfn.input_names
                ]

        # When importing, provide the original output op names
        tf.import_graph_def(gfn.graph_def,
                            input_map=input_map,
                            return_elements=gfn.output_names,
                            name=scope_name,
                            **gdef_kargs)
        feeds = [tfx.get_tensor(name, self.graph) for name in input_names]
        fetches = [tfx.get_tensor(name, self.graph) for name in output_names]
        return (feeds, fetches)
def _check_output(gin, tf_input, expected):
    """
    Takes a TFInputGraph object (assumed to have the input and outputs of the given
    names above) and compares the outcome against some expected outcome.
    """
    graph = tf.Graph()
    graph_def = gin.graph_def
    with tf.Session(graph=graph) as sess:
        tf.import_graph_def(graph_def, name="")
        tgt_feed = tfx.get_tensor(_tensor_input_name, graph)
        tgt_fetch = tfx.get_tensor(_tensor_output_name, graph)
        # Run on the testing target
        tgt_out = sess.run(tgt_fetch, feed_dict={tgt_feed: tf_input})
        # Working on integers, the calculation should be exact
        assert np.all(tgt_out == expected), (tgt_out, expected)
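A minimal usage sketch for importGraphFunction follows. It assumes a GraphFunction `gfn` built elsewhere and an IsolatedSession `run` helper; both are assumptions rather than facts taken from this listing.

# Usage sketch (assumptions: `gfn` has a single input and output, and
# IsolatedSession exposes a `run` helper that delegates to its session).
with IsolatedSession() as issn:
    feeds, fetches = issn.importGraphFunction(gfn, prefix='myScope')
    result = issn.run(fetches[0], feed_dict={feeds[0]: input_batch})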
Example #7
def _build_saved_model(session, saved_model_dir):
    """
    Saves a model in a file. The graph is assumed to be generated with _build_graph_novar.
    """
    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
    input_tensor = tfx.get_tensor(_tensor_input_name, session.graph)
    output_tensor = tfx.get_tensor(_tensor_output_name, session.graph)
    sig_inputs = {_tensor_input_signature: tf.saved_model.utils.build_tensor_info(input_tensor)}
    sig_outputs = {_tensor_output_signature: tf.saved_model.utils.build_tensor_info(output_tensor)}
    serving_sigdef = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=sig_inputs, outputs=sig_outputs)

    builder.add_meta_graph_and_variables(
        session, [_serving_tag], signature_def_map={_serving_sigdef_key: serving_sigdef})
    builder.save()
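A quick way to sanity-check the export written above is to load it back with the TF1 SavedModel loader and read the signature that was just built. This is a sketch that reuses saved_model_dir and the signature fixture names from the helper.

# Sketch: reload the export and look up the serving signature written above.
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(sess, [_serving_tag], saved_model_dir)
    sig_def = meta_graph_def.signature_def[_serving_sigdef_key]
    print(sig_def.inputs[_tensor_input_signature].name)
    print(sig_def.outputs[_tensor_output_signature].name)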
Example #10
def _build_with_sig_def(sess, graph, sig_def):
    # pylint: disable=protected-access
    assert sig_def, 'signature_def must not be None'

    with sess.as_default(), graph.as_default():
        feed_mapping = {}
        feed_names = []
        for sigdef_key, tnsr_info in sig_def.inputs.items():
            tnsr_name = tnsr_info.name
            feed_mapping[sigdef_key] = tnsr_name
            feed_names.append(tnsr_name)

        fetch_mapping = {}
        fetch_names = []
        for sigdef_key, tnsr_info in sig_def.outputs.items():
            tnsr_name = tnsr_info.name
            fetch_mapping[sigdef_key] = tnsr_name
            fetch_names.append(tnsr_name)

        for tnsr_name in feed_names:
            assert tfx.get_op(tnsr_name, graph), \
                'requested tensor {} but found none in graph {}'.format(tnsr_name, graph)
        fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names]
        graph_def = tfx.strip_and_freeze_until(fetches, graph, sess)

    return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=feed_mapping,
                        output_tensor_name_from_signature=fetch_mapping)
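One way to obtain the sig_def argument used above is to parse it out of the .meta file written by _build_checkpointed_model in Example #1. The sketch below assumes ckpt_meta_fpath points at that file and that the session and graph passed to _build_with_sig_def already hold the restored graph.

# Sketch: read the signature_def back from a checkpoint meta graph file.
from tensorflow.core.protobuf import meta_graph_pb2

meta = meta_graph_pb2.MetaGraphDef()
with open(ckpt_meta_fpath, 'rb') as fin:
    meta.ParseFromString(fin.read())
sig_def = meta.signature_def[_serving_sigdef_key]
# The graph itself can be restored with tf.train.import_meta_graph(ckpt_meta_fpath)
# before calling _build_with_sig_def(sess, sess.graph, sig_def).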
Example #11
    def _transform(self, dataset):
        if any([field.dataType == DoubleType() for field in dataset.schema]):
            logger.warning("Detected DoubleType columns in dataframe passed to transform(). In "
                           "Deep Learning Pipelines 1.0 and above, DoubleType columns can only be "
                           "fed to input tensors of type tf.float64. To feed dataframe data to "
                           "tensors of other types (e.g. tf.float32, tf.int32, tf.int64), use the "
                           "corresponding Spark SQL data types (FloatType, IntegerType, LongType).")

        graph_def = self._optimize_for_inference()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()

        graph = tf.Graph()
        with tf.Session(graph=graph):
            analyzed_df = tfs.analyze(dataset)
            out_tnsr_op_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping]
            # Load graph
            tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names)
            # Feed dict maps from placeholder name to DF column name
            feed_dict = {tfx.op_name(tnsr_name): col_name for col_name, tnsr_name in input_mapping}
            fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in out_tnsr_op_names]
            out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict)
            # We still have to rename output columns
            for tnsr_name, new_colname in output_mapping:
                old_colname = tfx.op_name(tnsr_name, graph)
                if old_colname != new_colname:
                    out_df = out_df.withColumnRenamed(old_colname, new_colname)

        return out_df
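A usage sketch for the transformer whose _transform method is shown above; the parameter names (tfInputGraph, inputMapping, outputMapping) and the tensor and column names are assumptions for illustration, not verified against the library.

# Sketch: constructing and applying the transformer (names are illustrative).
transformer = TFTransformer(tfInputGraph=gin,
                            inputMapping={'features': 'input_tensor:0'},
                            outputMapping={'output_tensor:0': 'predictions'})
out_df = transformer.transform(df)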
Example #13
    def _transform(self, dataset):
        graph_def = self._optimize_for_inference()
        input_mapping = self.getInputMapping()
        output_mapping = self.getOutputMapping()

        graph = tf.Graph()
        with tf.Session(graph=graph):
            analyzed_df = tfs.analyze(dataset)
            out_tnsr_op_names = [
                tfx.op_name(tnsr_name) for tnsr_name, _ in output_mapping
            ]
            # Load graph
            tf.import_graph_def(graph_def=graph_def,
                                name='',
                                return_elements=out_tnsr_op_names)

            # Feed dict maps from placeholder name to DF column name
            feed_dict = {
                self._getSparkDlOpName(tnsr_name): col_name
                for col_name, tnsr_name in input_mapping
            }
            fetches = [
                tfx.get_tensor(tnsr_name, graph)
                for tnsr_name in out_tnsr_op_names
            ]

            out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict)
            # We still have to rename output columns
            for tnsr_name, new_colname in output_mapping:
                old_colname = tfx.op_name(tnsr_name, graph)
                if old_colname != new_colname:
                    out_df = out_df.withColumnRenamed(old_colname, new_colname)

        return out_df
Example #15
    def test_get_graph_elements(self):
        """ Fetching graph elements by names and other graph elements """

        with IsolatedSession() as issn:
            x = tf.placeholder(tf.double, shape=[], name="x")
            z = tf.add(x, 3, name='z')

            g = issn.graph
            self.assertEqual(tfx.get_tensor(g, z), z)
            self.assertEqual(tfx.get_tensor(g, x), x)
            self.assertEqual(g.get_tensor_by_name("x:0"), tfx.get_tensor(g, x))
            self.assertEqual("x:0", tfx.tensor_name(g, x))
            self.assertEqual(g.get_operation_by_name("x"), tfx.get_op(g, x))
            self.assertEqual("x", tfx.op_name(g, x))
            self.assertEqual("z", tfx.op_name(g, z))
            self.assertEqual(tfx.tensor_name(g, z), "z:0")
            self.assertEqual(tfx.tensor_name(g, x), "x:0")
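Note that this test passes the graph as the first argument (for example tfx.get_tensor(g, z) and tfx.op_name(g, x)), whereas most other examples in this listing pass the tensor, op, or name first and the graph second; the importGraphFunction variant in Example #16 uses the same graph-first order, so these two snippets appear to target a different version of the tfx helpers than the rest.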
Example #16
    def importGraphFunction(self, gfn, input_map=None, prefix="GFN-IMPORT", **gdef_kargs):
        """
        Import a GraphFunction object into the current session.
        The API is similar to :py:meth:`tf.import_graph_def`.

        .. _a link: https://www.tensorflow.org/api_docs/python/tf/import_graph_def

        :param gfn: GraphFunction, an object representing a TensorFlow graph and its inputs and outputs
        :param input_map: dict, mapping from input names to existing graph elements
        :param prefix: str, the scope for all the variables in the :py:class:`GraphFunction` elements

                       .. _a link: https://www.tensorflow.org/programmers_guide/variable_scope

        :param gdef_kargs: other keyword arguments passed to TensorFlow's `import_graph_def`
        """
        try:
            del gdef_kargs["return_elements"]
        except KeyError:
            pass
        if input_map is not None:
            assert set(input_map.keys()) <= set(gfn.input_names), \
                "cannot locate provided input elements in the graph"

        input_names = gfn.input_names
        output_names = gfn.output_names
        scope_name = prefix
        if prefix is not None:
            scope_name = prefix.strip()
            if len(scope_name) > 0:
                output_names = [
                    scope_name + '/' + op_name for op_name in gfn.output_names]
                input_names = [
                    scope_name + '/' + op_name for op_name in gfn.input_names]

        # When importing, provide the original output op names
        tf.import_graph_def(gfn.graph_def,
                            input_map=input_map,
                            return_elements=gfn.output_names,
                            name=scope_name,
                            **gdef_kargs)
        feeds = [tfx.get_tensor(self.graph, name) for name in input_names]
        fetches = [tfx.get_tensor(self.graph, name) for name in output_names]
        return (feeds, fetches)
    def _get_placeholder_types(self, user_graph_def):
        """ Returns a list of placeholder type enums for the input nodes """
        user_graph = tf.Graph()
        with user_graph.as_default():
            # Load user-specified graph into memory, then get the data type of each input node
            tf.import_graph_def(user_graph_def, name="")
            res = []
            for _, tensor_name in self.getInputMapping():
                placeholder_type = tfx.get_tensor(tensor_name, user_graph).dtype.as_datatype_enum
                res.append(placeholder_type)
        return res
Example #19
def _build_with_feeds_fetches(sess, graph, feed_names, fetch_names):
    assert feed_names is not None, "must provide feed_names"
    assert fetch_names is not None, "must provide fetch names"

    with sess.as_default(), graph.as_default():
        for tnsr_name in feed_names:
            assert tfx.get_op(tnsr_name, graph), \
                'requested tensor {} but found none in graph {}'.format(tnsr_name, graph)
        fetches = [tfx.get_tensor(tnsr_name, graph) for tnsr_name in fetch_names]
        graph_def = tfx.strip_and_freeze_until(fetches, graph, sess)

    return TFInputGraph(graph_def=graph_def, input_tensor_name_from_signature=None,
                        output_tensor_name_from_signature=None)
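A usage sketch for _build_with_feeds_fetches, assuming a session whose graph already contains the named tensors; the names reuse the fixture constants assumed earlier in this listing.

# Sketch: build a TFInputGraph directly from explicit feed/fetch tensor names.
gin = _build_with_feeds_fetches(sess, sess.graph,
                                feed_names=[_tensor_input_name + ':0'],
                                fetch_names=[_tensor_output_name + ':0'])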
Example #20
def _gen_invalid_tensor_or_op_with_graph_pairing():
    tnsr = tf.constant(1427.08, name='someConstOp')
    other_graph = tf.Graph()
    op_name = tnsr.op.name

    # Test get_tensor and get_op with non-associated tensor/op and graph inputs
    _comm_suffix = ' with wrong graph'
    yield TestCase(data=lambda: tfx.get_op(tnsr, other_graph),
                   description='test get_op from tensor' + _comm_suffix)
    yield TestCase(data=lambda: tfx.get_tensor(tnsr, other_graph),
                   description='test get_tensor from tensor' + _comm_suffix)
    yield TestCase(data=lambda: tfx.get_op(tnsr.name, other_graph),
                   description='test get_op from tensor name' + _comm_suffix)
    yield TestCase(data=lambda: tfx.get_tensor(tnsr.name, other_graph),
                   description='test get_tensor from tensor name' + _comm_suffix)
    yield TestCase(data=lambda: tfx.get_op(tnsr.op, other_graph),
                   description='test get_op from op' + _comm_suffix)
    yield TestCase(data=lambda: tfx.get_tensor(tnsr.op, other_graph),
                   description='test get_tensor from op' + _comm_suffix)
    yield TestCase(data=lambda: tfx.get_op(op_name, other_graph),
                   description='test get_op from op name' + _comm_suffix)
    yield TestCase(data=lambda: tfx.get_tensor(op_name, other_graph),
                   description='test get_tensor from op name' + _comm_suffix)
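These cases are presumably consumed by a parameterized test that expects each lambda to fail because the tensor or op does not belong to other_graph. A sketch of such a consumer is shown below; the expected exception type is left deliberately broad since the source does not specify it.

# Sketch of a parameterized consumer for the invalid-pairing cases above.
def test_invalid_tensor_or_op_with_graph(self, data, description):
    with self.assertRaises(Exception, msg=description):
        data()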
Example #23
    def _addCastOps(self, user_graph_def):
        """
        Given a GraphDef object corresponding to a user-specified graph G, creates a copy G'
        of G with ops injected before each input node. The injected ops allow the input nodes of G'
        to accept tf.float64 input fed from Spark, casting float64 input into the datatype
        requested by each input node.

        :return: GraphDef representing the copied, modified graph.
        """
        # Load user-specified graph into memory
        user_graph = tf.Graph()
        with user_graph.as_default():
            tf.import_graph_def(user_graph_def, name="")

        # Build a subgraph containing our injected ops
        # TODO: Cheap optimization: if all input tensors are of type float64, just do nothing here
        injected_op_subgraph = tf.Graph()
        # Maps names of input tensors in our original graph to outputs of the injected-op subgraph
        input_map = {}
        with injected_op_subgraph.as_default():
            with tf.name_scope(self.SPARKDL_OP_SCOPE):
                for _, orig_tensor_name in self.getInputMapping():
                    orig_tensor = tfx.get_tensor(orig_tensor_name, user_graph)
                    # Create placeholder with same shape as original input tensor, but that accepts
                    # float64 input from Spark.
                    spark_placeholder = tf.placeholder(
                        tf.float64,
                        shape=orig_tensor.shape,
                        name=tfx.op_name(orig_tensor_name))
                    # If the original tensor was of type float64, just pass through the Spark input
                    if orig_tensor.dtype == tf.float64:
                        input_map[orig_tensor_name] = spark_placeholder
                    # Otherwise, cast the Spark input to the datatype of the original tensor
                    else:
                        input_map[orig_tensor_name] = tf.cast(
                            spark_placeholder, dtype=orig_tensor.dtype)
            tf.import_graph_def(graph_def=user_graph_def,
                                input_map=input_map,
                                name="")
        return injected_op_subgraph.as_graph_def(add_shapes=True)
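After this rewrite the injected feed placeholders live under the self.SPARKDL_OP_SCOPE name scope (for example, an input op named 'input_tensor' becomes a placeholder whose op name carries that scope as a prefix), which is presumably why the _transform variant in Example #13 resolves feed names with self._getSparkDlOpName(tnsr_name) rather than tfx.op_name(tnsr_name).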
Example #24
    def test_invalid_tensor_inputs_with_wrong_types(self, data, description):
        """ Must fail when provided wrong types """
        with self.assertRaises(TypeError, msg=description):
            tfx.get_tensor(data, tf.Graph())
Example #26
def _gen_valid_tensor_op_input_combos():
    op_name = 'someConstOp'
    tnsr_name = '{}:0'.format(op_name)
    tnsr = tf.constant(1427.08, name=op_name)
    graph = tnsr.graph

    # Test for op_name
    yield TestCase(data=(op_name, tfx.op_name(tnsr)),
                   description='get op name from tensor (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr, graph)),
                   description='get op name from tensor (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr_name)),
                   description='get op name from tensor name (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr_name, graph)),
                   description='get op name from tensor name (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr.op)),
                   description='get op name from op (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(tnsr.op, graph)),
                   description='get op name from op (with graph)')
    yield TestCase(data=(op_name, tfx.op_name(op_name)),
                   description='get op name from op name (no graph)')
    yield TestCase(data=(op_name, tfx.op_name(op_name, graph)),
                   description='get op name from op name (with graph)')

    # Test for tensor_name
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr)),
                   description='get tensor name from tensor (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr, graph)),
                   description='get tensor name from tensor (with graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)),
                   description='get tensor name from tensor name (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name, graph)),
                   description='get tensor name from tensor name (with graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op)),
                   description='get tensor name from op (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr.op, graph)),
                   description='get tensor name from op (with graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name)),
                   description='get tensor name from op name (no graph)')
    yield TestCase(data=(tnsr_name, tfx.tensor_name(op_name, graph)),
                   description='get tensor name from op name (with graph)')

    # Test for get_tensor
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr, graph)),
                   description='get tensor from tensor')
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr_name, graph)),
                   description='get tensor from tensor name')
    yield TestCase(data=(tnsr, tfx.get_tensor(tnsr.op, graph)),
                   description='get tensor from op')
    yield TestCase(data=(tnsr, tfx.get_tensor(op_name, graph)),
                   description='get tensor from op name')

    # Test for get_op
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)),
                   description='get op from tensor')
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr_name, graph)),
                   description='get op from tensor name')
    yield TestCase(data=(tnsr.op, tfx.get_op(tnsr.op, graph)),
                   description='get op from op')
    yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)),
                   description='get op from op name')
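A parameterized test can consume these (expected, actual) pairs directly; the sketch below is one plausible shape for such a consumer, not the one from the original suite.

# Sketch of a parameterized consumer for the valid-combination cases above.
def test_valid_tensor_op_input_combos(self, data, description):
    expected, actual = data
    self.assertEqual(expected, actual, msg=description)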