Exemplo n.º 1
0
  def testConvertTensors(self):
    """convert_to_input_tensors: dtype coercion, sparsity and key mismatch.

    Note: `assertEquals` / `assertRaisesRegexp` were deprecated aliases
    (removed in Python 3.12); this uses the canonical spellings.
    """
    a = tf.placeholder(tf.int32, [None])
    protomap = _make_signature({"a": a}, {}).inputs

    # A plain Python list is converted to a tensor whose dtype matches the
    # signature's declared input dtype.
    in0 = [1, 2, 3]
    output = tensor_info.convert_to_input_tensors(protomap, {"a": in0})
    self.assertEqual(output["a"].dtype, a.dtype)

    # A sparse tensor must be rejected where the signature expects dense.
    in1 = tf.sparse_placeholder(tf.int32, [])
    with self.assertRaisesRegex(TypeError, "dense"):
      tensor_info.convert_to_input_tensors(protomap, {"a": in1})

    # Feeding the wrong key must report the missing signature input.
    with self.assertRaisesRegex(TypeError, "missing"):
      tensor_info.convert_to_input_tensors(protomap, {"b": in1})
Exemplo n.º 2
0
    def create_apply_graph(self, signature, inputs, name):
        """See `ModuleImpl.create_apply_graph`."""
        sig_def = self._meta_graph.signature_def.get(signature)

        converted = tensor_info.convert_to_input_tensors(
            sig_def.inputs, inputs)

        # Start from the module's state tensors, then layer the caller's inputs
        # on top, so an explicit input overrides a state-graph tensor.
        feeds = dict(self._state_map)
        feeds.update(tensor_info.build_input_map(sig_def.inputs, converted))

        # Lift state tensors into the active control flow context (if any) so
        # the Module can be applied inside constructs such as a while_loop.
        ctx = self._graph._get_control_flow_context()  # pylint: disable=protected-access
        if ctx:
            for feed_key, feed_value in sorted(feeds.items()):
                feeds[feed_key] = ctx.AddValue(feed_value)

        # Reserve the scope name without consuming it; import_scoped_meta_graph
        # is the one that starts using it.
        abs_scope = self._graph.unique_name(name, mark_as_used=False)
        rel_scope = abs_scope.split("/")[-1]

        wanted_collections = [
            # ASSET_FILEPATHS are mostly used by TABLE_INITIALIZERS ops, but a
            # graph could use an asset at any other time. So every tensor that
            # carries an asset filename is annotated as such, letting later
            # re-exports keep that semantic information and handle it.
            tf.GraphKeys.ASSET_FILEPATHS,
            tf.GraphKeys.COND_CONTEXT,
            tf.GraphKeys.WHILE_CONTEXT,
        ]
        if self._trainable:
            wanted_collections.append(tf.GraphKeys.UPDATE_OPS)

        tf.train.import_meta_graph(
            adapted_meta_graph_for_import(self._meta_graph, abs_scope),
            input_map=feeds,
            import_scope=rel_scope,
            restore_collections_predicate=(
                lambda key: key in wanted_collections))
        fix_colocation_after_import(input_map=feeds,
                                    absolute_import_scope=abs_scope)

        def lookup(tensor_name):
            # Outputting an input tensor creates no node inside the apply
            # scope, so the feed map must be consulted first.
            try:
                return feeds[tensor_name]
            except KeyError:
                return self._graph.get_tensor_by_name(
                    prepend_name_scope(tensor_name, import_scope=abs_scope))

        return tensor_info.build_output_map(sig_def.outputs, lookup)
Exemplo n.º 3
0
  def create_apply_graph(self, signature, inputs, name):
    """See `ModuleImpl.create_apply_graph`.

    Imports the stored MetaGraphDef into the current graph under a fresh
    name scope, feeding `inputs` (merged over the module's state tensors)
    through an input map, and returns the signature's output tensors.

    Args:
      signature: string key into the MetaGraphDef's signature_def map.
        NOTE(review): assumes the key is present; a missing key would make
        `signature_def` None and fail below -- confirm callers validate it.
      inputs: dict of values accepted by
        `tensor_info.convert_to_input_tensors` for this signature.
      name: requested name scope for the applied graph.

    Returns:
      A dict of output tensors built by `tensor_info.build_output_map`.
    """
    signature_def = self._meta_graph.signature_def.get(signature)

    # Coerce caller-supplied values to tensors matching the signature.
    input_tensors = tensor_info.convert_to_input_tensors(
        signature_def.inputs, inputs)

    # Build a input map to feed when importing the apply-graph by augmenting the
    # state_map with the input args. This allows an input to override a tensor
    # from the state-graph.
    feed_map = dict(self._state_map)
    feed_map.update(
        tensor_info.build_input_map(signature_def.inputs, input_tensors))

    # Make state tensors enter the current context. This way the Module can be
    # applied inside a control flow structure such as a while_loop.
    control_flow = self._graph._get_control_flow_context()  # pylint: disable=protected-access
    if control_flow:
      # sorted() makes the AddValue order deterministic across runs.
      for key, value in sorted(feed_map.items()):
        feed_map[key] = control_flow.AddValue(value)

    # Don't mark the name as used at this point - import_scoped_meta_graph will
    # start using it.
    absolute_scope_name = self._graph.unique_name(name, mark_as_used=False)
    relative_scope_name = absolute_scope_name.split("/")[-1]

    import_collections = [
        # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS
        # ops, however one could create a graph that uses an asset at any other
        # time. As so everytime we bring the tensor with that has the asset
        # filename we must annotate it as so, so later re-exports have that
        # semantic information and can handle it.
        tf.GraphKeys.ASSET_FILEPATHS,
        tf.GraphKeys.COND_CONTEXT,
        tf.GraphKeys.WHILE_CONTEXT,
    ]
    if self._trainable:
      import_collections.extend([tf.GraphKeys.UPDATE_OPS])

    tf.train.import_meta_graph(
        adapted_meta_graph_for_import(self._meta_graph, absolute_scope_name),
        input_map=feed_map,
        import_scope=relative_scope_name,
        restore_collections_predicate=(lambda key: key in import_collections))
    # Re-establish colocation constraints that the input_map rewiring broke.
    fix_colocation_after_import(input_map=feed_map,
                                absolute_import_scope=absolute_scope_name)

    def get_tensor(name):
      # When trying to output an input tensor there are no nodes created within
      # the apply scope. So one must look into the input map.
      try:
        return feed_map[name]
      except KeyError:
        return self._graph.get_tensor_by_name(
            prepend_name_scope(name, import_scope=absolute_scope_name))

    return tensor_info.build_output_map(signature_def.outputs, get_tensor)