예제 #1
0
    def training_graph(self,
                       input_data,
                       input_labels,
                       random_seed,
                       data_spec,
                       sparse_features=None,
                       input_weights=None):
        """Builds the ops that perform one training step for this tree.

        Args:
          input_data: Input data handed to the tree ops; None is replaced by
            an empty list.
          input_labels: Labels forwarded to the model-update and stats ops.
          random_seed: Seed forwarded to `process_input_v4`.
          data_spec: Object whose `SerializeToString()` output is passed to
            the ops as `input_spec`.
          sparse_features: Optional object exposing `indices`, `values` and
            `dense_shape`; when None, empty lists stand in for all three.
          input_weights: Optional per-input weights; None is replaced by an
            empty list.

        Returns:
          The result of `grow_tree_v4`, created under a control dependency on
          the model-update op.
        """
        # Normalize the optional inputs to the empty-list placeholders that
        # the ops receive when nothing was supplied.
        weights = [] if input_weights is None else input_weights

        if sparse_features is None:
            sp_indices, sp_values, sp_shape = [], [], []
        else:
            sp_indices = sparse_features.indices
            sp_values = sparse_features.values
            sp_shape = sparse_features.dense_shape

        dense = [] if input_data is None else input_data

        tree = self.variables.tree
        serialized_params = self.params.serialized_params_proto

        # Find the leaf each input lands in.
        leaves = model_ops.traverse_tree_v4(
            tree, dense, sp_indices, sp_values, sp_shape,
            input_spec=data_spec.SerializeToString(),
            params=serialized_params)

        model_update_op = model_ops.update_model_v4(
            tree, leaves, input_labels, weights,
            params=serialized_params)

        ready_nodes = stats_ops.process_input_v4(
            tree, self.variables.stats,
            dense, sp_indices, sp_values, sp_shape,
            input_labels, weights, leaves,
            input_spec=data_spec.SerializeToString(),
            random_seed=random_seed,
            params=serialized_params)

        # The grow op must be created inside the control-dependency scope so
        # that it is ordered after the model update.
        with ops.control_dependencies([model_update_op]):
            grow_op = stats_ops.grow_tree_v4(
                tree, self.variables.stats, ready_nodes,
                params=serialized_params)
        return grow_op
예제 #2
0
  def training_graph(self, input_data,
                     input_labels,
                     random_seed,
                     data_spec,
                     sparse_features=None,
                     input_weights=None):
    """Assembles the training ops for a single tree.

    Args:
      input_data: Input data for the tree ops; an empty list is used when None.
      input_labels: Labels passed to `update_model_v4` and `process_input_v4`.
      random_seed: Seed passed to `process_input_v4`.
      data_spec: Object serialized with `SerializeToString()` and passed to the
        ops as `input_spec`.
      sparse_features: Optional object providing `indices`, `values` and
        `dense_shape`; empty lists are used for all three when None.
      input_weights: Optional per-input weights; an empty list is used when
        None.

    Returns:
      The op returned by `grow_tree_v4`, built under a control dependency on
      the op returned by `update_model_v4`.
    """
    weights = input_weights
    if weights is None:
      weights = []

    # (indices, values, shape) of the sparse input, or empty placeholders.
    if sparse_features is not None:
      sparse_args = (sparse_features.indices,
                     sparse_features.values,
                     sparse_features.dense_shape)
    else:
      sparse_args = ([], [], [])

    features = input_data
    if features is None:
      features = []

    variables = self.variables
    proto = self.params.serialized_params_proto

    leaf_assignments = model_ops.traverse_tree_v4(
        variables.tree,
        features,
        *sparse_args,
        input_spec=data_spec.SerializeToString(),
        params=proto)

    model_update = model_ops.update_model_v4(
        variables.tree,
        leaf_assignments,
        input_labels,
        weights,
        params=proto)

    finished = stats_ops.process_input_v4(
        variables.tree,
        variables.stats,
        features,
        *sparse_args,
        input_labels,
        weights,
        leaf_assignments,
        input_spec=data_spec.SerializeToString(),
        random_seed=random_seed,
        params=proto)

    # Growing the tree is sequenced after the model update.
    with ops.control_dependencies([model_update]):
      return stats_ops.grow_tree_v4(
          variables.tree, variables.stats, finished, params=proto)
예제 #3
0
    def training_graph(self,
                       input_data,
                       input_labels,
                       random_seed,
                       data_spec,
                       sparse_features=None,
                       input_weights=None):
        """Builds the TF graph that trains one random tree.

        Args:
          input_data: A tensor or placeholder holding the input data. None is
            replaced by an empty list.
          input_labels: A tensor or placeholder holding the labels that go
            with input_data.
          random_seed: Seed for this tree's random number generator. 0 means
            the current time is used as the seed.
          data_spec: A data_ops.TensorForestDataSpec describing the original
            features/columns of the data.
          sparse_features: A tf.SparseTensor with sparse input data, or None.
          input_weights: A float tensor or placeholder with per-input weights,
            or None to weight all inputs equally.

        Returns:
          The last op in the random tree training graph.
        """
        # TODO(gilberth): Use this.
        # The value is not consumed yet; the call is kept so that whatever it
        # does stays part of this function.
        unused_epoch = math_ops.to_int32(get_epoch_variable())

        weights = [] if input_weights is None else input_weights

        # Empty lists stand in for the sparse components when there is no
        # sparse input.
        sp_indices, sp_values, sp_shape = [], [], []
        if sparse_features is not None:
            sp_indices, sp_values, sp_shape = (
                sparse_features.indices,
                sparse_features.values,
                sparse_features.dense_shape)

        dense = [] if input_data is None else input_data

        tree_var = self.variables.tree
        stats_var = self.variables.stats
        proto = self.params.serialized_params_proto

        leaves = model_ops.traverse_tree_v4(
            tree_var, dense, sp_indices, sp_values, sp_shape,
            input_spec=data_spec.SerializeToString(),
            params=proto)

        model_update_op = model_ops.update_model_v4(
            tree_var, leaves, input_labels, weights, params=proto)

        done_nodes = stats_ops.process_input_v4(
            tree_var, stats_var,
            dense, sp_indices, sp_values, sp_shape,
            input_labels, weights, leaves,
            input_spec=data_spec.SerializeToString(),
            random_seed=random_seed,
            params=proto)

        # Create the grow op inside the scope so it depends on the update.
        with ops.control_dependencies([model_update_op]):
            grow_op = stats_ops.grow_tree_v4(
                tree_var, stats_var, done_nodes, params=proto)
        return grow_op
예제 #4
0
  def training_graph(self,
                     input_data,
                     input_labels,
                     random_seed,
                     data_spec,
                     sparse_features=None,
                     input_weights=None):
    """Creates the TF training graph for a single random tree.

    Args:
      input_data: Tensor or placeholder with the input data; None is replaced
        by an empty list.
      input_labels: Tensor or placeholder with the labels for input_data.
      random_seed: Random number generator seed for this tree. A value of 0
        means the current time is used as the seed.
      data_spec: A data_ops.TensorForestDataSpec describing the original
        features/columns of the data.
      sparse_features: A tf.SparseTensor carrying sparse input data, or None.
      input_weights: Float tensor or placeholder of per-input weights, or None
        when every input is weighted equally.

    Returns:
      The last op in the random tree training graph.
    """
    # TODO(gilberth): Use this.
    # Result is currently unused; the call itself is intentionally retained.
    unused_epoch = math_ops.to_int32(get_epoch_variable())

    weights = input_weights
    if weights is None:
      weights = []

    # Sparse input as (indices, values, shape); empty placeholders otherwise.
    if sparse_features is not None:
      sparse_args = (sparse_features.indices,
                     sparse_features.values,
                     sparse_features.dense_shape)
    else:
      sparse_args = ([], [], [])

    features = input_data
    if features is None:
      features = []

    variables = self.variables
    serialized_params = self.params.serialized_params_proto

    leaf_assignments = model_ops.traverse_tree_v4(
        variables.tree,
        features,
        *sparse_args,
        input_spec=data_spec.SerializeToString(),
        params=serialized_params)

    model_update = model_ops.update_model_v4(
        variables.tree,
        leaf_assignments,
        input_labels,
        weights,
        params=serialized_params)

    finished = stats_ops.process_input_v4(
        variables.tree,
        variables.stats,
        features,
        *sparse_args,
        input_labels,
        weights,
        leaf_assignments,
        input_spec=data_spec.SerializeToString(),
        random_seed=random_seed,
        params=serialized_params)

    # The tree is grown only after the model update has run.
    with ops.control_dependencies([model_update]):
      return stats_ops.grow_tree_v4(
          variables.tree,
          variables.stats,
          finished,
          params=serialized_params)