Example #1
0
def tree_variable(params, tree_config, stats_handle, name, container=None):
    """Builds a decision-tree resource and returns its handle.

    Inside a name scope derived from `name`, this creates the resource handle,
    its create op and its is-initialized op. It adds a `TreeVariableSavable`
    to the SAVEABLE_OBJECTS collection and registers the resource with its
    create / is-initialized ops.

    Args:
        params: A TensorForestParams object; its `serialized_params_proto`
            is forwarded to the create op.
        tree_config: A `Tensor` of type `string`. Serialized proto of the tree.
        stats_handle: Resource handle to the stats object.
        name: A name for the variable (also used as the scope and shared name).
        container: An optional `string` container name. Defaults to `None`,
            which is forwarded unchanged to the handle op.

    Returns:
        A `Tensor` of type mutable `string`. The handle to the tree.
    """
    with ops.name_scope(name, "TreeVariable") as name:
        # The scope name doubles as both the shared name and the op name.
        handle = gen_model_ops.decision_tree_resource_handle_op(
            container, shared_name=name, name=name)

        serialized_params = params.serialized_params_proto
        init_op = gen_model_ops.create_tree_variable(
            handle, tree_config, params=serialized_params)
        initialized_check = gen_model_ops.tree_is_initialized_op(handle)

        # Make the tree checkpointable under a key derived from the scope name.
        checkpoint_key = "tree_checkpoint_{0}".format(name)
        saveable = TreeVariableSavable(params, handle, stats_handle, init_op,
                                       checkpoint_key)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
        resources.register_resource(handle, init_op, initialized_check)
    return handle
Example #2
0
def tree_variable(params, tree_config, stats_handle, name, container=None):
  """Builds a decision-tree resource and returns its handle.

  Inside a name scope derived from `name`, this creates the resource handle,
  its create op and its is-initialized op. It adds a `TreeVariableSavable`
  (keyed by the handle tensor's name) to the SAVEABLE_OBJECTS collection and
  registers the resource with its create / is-initialized ops.

  Args:
    params: A TensorForestParams object; its `serialized_params_proto` is
      forwarded to the create op.
    tree_config: A `Tensor` of type `string`. Serialized proto of the tree.
    stats_handle: Resource handle to the stats object.
    name: A name for the variable (also used as the scope and shared name).
    container: An optional `string` container name. Defaults to `None`,
      which is forwarded unchanged to the handle op.

  Returns:
    A `Tensor` of type mutable `string`. The handle to the tree.
  """
  with ops.name_scope(name, "TreeVariable") as name:
    tree_handle = gen_model_ops.decision_tree_resource_handle_op(
        container, shared_name=name, name=name)

    initializer = gen_model_ops.create_tree_variable(
        tree_handle, tree_config, params=params.serialized_params_proto)
    is_ready = gen_model_ops.tree_is_initialized_op(tree_handle)

    # Checkpoint the tree under the handle tensor's own name.
    ops.add_to_collection(
        ops.GraphKeys.SAVEABLE_OBJECTS,
        TreeVariableSavable(params, tree_handle, stats_handle, initializer,
                            tree_handle.name))
    resources.register_resource(tree_handle, initializer, is_ready)
  return tree_handle
Example #3
0
 def _create_resource(self):
     """Creates the decision-tree resource handle op.

     When executing eagerly, every call gets a fresh shared name built from
     `ops.uid()`. Otherwise `self._name` is reused as the shared name. The op
     is always named `self._name` and receives `self._container` as its
     container.

     Returns:
         The result of `gen_model_ops.decision_tree_resource_handle_op`.
     """
     # TODO(allenl): The eager branch leaks memory because kernels are cached
     # by the shared_name attribute value. That is still better than sharing
     # everything by default when executing eagerly; hopefully creating tables
     # in a loop is uncommon.
     shared_name = ("tree_variable_%d" % (ops.uid(), )
                    if context.executing_eagerly() else self._name)
     return gen_model_ops.decision_tree_resource_handle_op(
         self._container, shared_name=shared_name, name=self._name)
Example #4
0
 def create_resource(self):
   """Returns a newly created decision-tree resource handle op.

   The shared name defaults to `self._name`. When executing eagerly it is
   replaced by a unique `tree_variable_<uid>` value. The op is named
   `self._name` and receives `self._container` as its container.
   """
   shared_name = self._name
   if context.executing_eagerly():
     # TODO(allenl): This leaks memory because kernels are cached by the
     # shared_name attribute value. That is still better than sharing
     # everything by default when executing eagerly; hopefully creating
     # tables in a loop is uncommon.
     shared_name = "tree_variable_%d" % (ops.uid(),)
   return gen_model_ops.decision_tree_resource_handle_op(
       self._container, shared_name=shared_name, name=self._name)