def testMissingLabel(self):
    """Checks classification stats when one of the labels is -1.

    Feeds labels [0, 1, -1, 3] with the fixture data. The expected root
    total is 3, so one of the four labelled examples is not counted.
    """
    labels = [0, 1, -1, 3]
    with self.test_session():
      stats = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [],
          labels, [],
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          input_spec=self.data_spec,
          num_classes=5,
          regression=False)
      # Outputs 2, 5 and 8 of the op are not checked by this test.
      (node_sums, _, split_idx, split_sums, _,
       total_idx, total_sums, _, leaves_out) = stats

      expected_node_sums = [[3., 1., 1., 0., 1.],
                            [2., 1., 1., 0., 0.],
                            [1., 0., 0., 0., 1.]]
      self.assertAllEqual(expected_node_sums, node_sums.eval())
      self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_idx.eval())
      self.assertAllEqual([1., 1.], split_sums.eval())
      self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_idx.eval())
      self.assertAllEqual([1., 2., 1.], total_sums.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
  def testSimpleWeighted(self):
    """Checks regression stats when per-example weights are supplied.

    Runs count_extremely_random_stats with weights [1.0, 2.0, 3.0, 4.0] and
    regression=True, then verifies every returned sums/squares/indices
    tensor and the leaves output.
    """
    with self.test_session():
      weights = [1.0, 2.0, 3.0, 4.0]
      stats = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [],
          self.data_spec,
          self.input_labels,
          weights,
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          num_classes=2,
          regression=True)
      (node_sums, node_squares, split_idx, split_sums, split_squares,
       total_idx, total_sums, total_squares, leaves_out) = stats

      # Per-node statistics.
      self.assertAllEqual(
          [[10., 33.], [3., 15.], [7., 18.]], node_sums.eval())
      self.assertAllEqual(
          [[10., 129.], [3., 81.], [7., 48.]], node_squares.eval())
      # Candidate split statistics.
      self.assertAllEqual([[0, 0]], split_idx.eval())
      self.assertAllEqual([[1., 3.]], split_sums.eval())
      self.assertAllEqual([[1., 9.]], split_squares.eval())
      # Accumulator totals.
      self.assertAllEqual([[0]], total_idx.eval())
      self.assertAllEqual([[2., 9.]], total_sums.eval())
      self.assertAllEqual([[2., 45.]], total_squares.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
Пример #3
0
  def testSimpleWeighted(self):
    """Checks regression stats when per-example weights are supplied.

    Runs count_extremely_random_stats with weights [1.0, 2.0, 3.0, 4.0],
    regression=True and the data spec passed as input_spec, then verifies
    every returned sums/squares/indices tensor and the leaves output.
    """
    with self.test_session():
      weights = [1.0, 2.0, 3.0, 4.0]
      stats = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [],
          self.input_labels,
          weights,
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          input_spec=self.data_spec,
          num_classes=2,
          regression=True)
      (node_sums, node_squares, split_idx, split_sums, split_squares,
       total_idx, total_sums, total_squares, leaves_out) = stats

      # Per-node statistics.
      self.assertAllEqual(
          [[10., 33.], [3., 15.], [7., 18.]], node_sums.eval())
      self.assertAllEqual(
          [[10., 129.], [3., 81.], [7., 48.]], node_squares.eval())
      # Candidate split statistics.
      self.assertAllEqual([[0, 0]], split_idx.eval())
      self.assertAllEqual([[1., 3.]], split_sums.eval())
      self.assertAllEqual([[1., 9.]], split_squares.eval())
      # Accumulator totals.
      self.assertAllEqual([[0]], total_idx.eval())
      self.assertAllEqual([[2., 9.]], total_sums.eval())
      self.assertAllEqual([[2., 45.]], total_squares.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
  def testNoAccumulators(self):
    """Checks the op outputs when every node-map entry is -1.

    [-1] * 3 is passed in the argument slot that follows tree_thresholds.
    Node sums are still expected, while the split and total outputs are
    expected to be empty (zero rows).
    """
    with self.test_session():
      (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
       pcw_totals_indices, pcw_totals_sums, _,
       leaves) = (tensor_forest_ops.count_extremely_random_stats(
           self.input_data, [], [], [],
           self.data_spec,
           self.input_labels, [],
           self.tree,
           self.tree_thresholds, [-1] * 3,
           self.split_features,
           self.split_thresholds,
           self.epochs,
           self.current_epoch,
           num_classes=5,
           regression=False))

      self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.],
                           [2., 0., 0., 1., 1.]],
                          pcw_node_sums.eval())
      # assertEqual is used instead of the deprecated assertEquals alias,
      # which no longer exists in Python 3.12+.
      self.assertEqual((0, 3), pcw_splits_indices.eval().shape)
      self.assertAllEqual([], pcw_splits_sums.eval())
      self.assertEqual((0, 2), pcw_totals_indices.eval().shape)
      self.assertAllEqual([], pcw_totals_sums.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves.eval())
    def testSimpleWeighted(self):
        """Checks classification stats when per-example weights are supplied.

        Runs count_extremely_random_stats with weights [1.5, 2.0, 3.0, 4.0]
        and regression=False, then verifies the sums, indices and leaves.
        """
        with self.test_session():
            weights = [1.5, 2.0, 3.0, 4.0]
            stats = tensor_forest_ops.count_extremely_random_stats(
                self.input_data, [], [], [],
                self.data_spec,
                self.input_labels,
                weights,
                self.tree,
                self.tree_thresholds,
                self.node_map,
                self.split_features,
                self.split_thresholds,
                self.epochs,
                self.current_epoch,
                num_classes=5,
                regression=False)
            # Outputs 2, 5 and 8 of the op are not checked by this test.
            (node_sums, _, split_idx, split_sums, _,
             total_idx, total_sums, _, leaves_out) = stats

            expected_node_sums = [[10.5, 1.5, 2., 3., 4.],
                                  [3.5, 1.5, 2., 0., 0.],
                                  [7., 0., 0., 3., 4.]]
            self.assertAllEqual(expected_node_sums, node_sums.eval())
            self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_idx.eval())
            self.assertAllEqual([1.5, 1.5], split_sums.eval())
            self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_idx.eval())
            self.assertAllEqual([2., 3.5, 1.5], total_sums.eval())
            self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
    def testMissingLabel(self):
        """Checks classification stats when one of the labels is -1.

        Feeds labels [0, 1, -1, 3] with the fixture data. The expected root
        total is 3, so one of the four labelled examples is not counted.
        """
        labels = [0, 1, -1, 3]
        with self.test_session():
            stats = tensor_forest_ops.count_extremely_random_stats(
                self.input_data, [], [], [],
                self.data_spec,
                labels, [],
                self.tree,
                self.tree_thresholds,
                self.node_map,
                self.split_features,
                self.split_thresholds,
                self.epochs,
                self.current_epoch,
                num_classes=5,
                regression=False)
            # Outputs 2, 5 and 8 of the op are not checked by this test.
            (node_sums, _, split_idx, split_sums, _,
             total_idx, total_sums, _, leaves_out) = stats

            expected_node_sums = [[3., 1., 1., 0., 1.],
                                  [2., 1., 1., 0., 0.],
                                  [1., 0., 0., 0., 1.]]
            self.assertAllEqual(expected_node_sums, node_sums.eval())
            self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_idx.eval())
            self.assertAllEqual([1., 1.], split_sums.eval())
            self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_idx.eval())
            self.assertAllEqual([1., 2., 1.], total_sums.eval())
            self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
    def testSimple(self):
        """Checks unweighted regression stats on the fixture data.

        Runs count_extremely_random_stats with an empty weights list and
        regression=True, then verifies every returned tensor.
        """
        with self.test_session():
            stats = tensor_forest_ops.count_extremely_random_stats(
                self.input_data, [], [], [],
                self.data_spec,
                self.input_labels, [],
                self.tree,
                self.tree_thresholds,
                self.node_map,
                self.split_features,
                self.split_thresholds,
                self.epochs,
                self.current_epoch,
                num_classes=2,
                regression=True)
            (node_sums, node_squares, split_idx, split_sums, split_squares,
             total_idx, total_sums, total_squares, leaves_out) = stats

            # Per-node statistics.
            self.assertAllEqual(
                [[4., 14.], [2., 9.], [2., 5.]], node_sums.eval())
            self.assertAllEqual(
                [[4., 58.], [2., 45.], [2., 13.]], node_squares.eval())
            # Candidate split statistics.
            self.assertAllEqual([[0, 0]], split_idx.eval())
            self.assertAllEqual([[1., 3.]], split_sums.eval())
            self.assertAllEqual([[1., 9.]], split_squares.eval())
            # Accumulator totals.
            self.assertAllEqual([[0]], total_idx.eval())
            self.assertAllEqual([[2., 9.]], total_sums.eval())
            self.assertAllEqual([[2., 45.]], total_squares.eval())
            self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
    def testThreaded(self):
        """Checks classification stats in a session with two intra-op threads.

        The session is created with intra_op_parallelism_threads=2; the
        expected outputs are asserted exactly.
        """
        session_config = tf.ConfigProto(intra_op_parallelism_threads=2)
        with self.test_session(config=session_config):
            stats = tensor_forest_ops.count_extremely_random_stats(
                self.input_data, [], [], [],
                self.data_spec,
                self.input_labels, [],
                self.tree,
                self.tree_thresholds,
                self.node_map,
                self.split_features,
                self.split_thresholds,
                self.epochs,
                self.current_epoch,
                num_classes=5,
                regression=False)
            # Outputs 2, 5 and 8 of the op are not checked by this test.
            (node_sums, _, split_idx, split_sums, _,
             total_idx, total_sums, _, leaves_out) = stats

            expected_node_sums = [[4., 1., 1., 1., 1.],
                                  [2., 1., 1., 0., 0.],
                                  [2., 0., 0., 1., 1.]]
            self.assertAllEqual(expected_node_sums, node_sums.eval())
            self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_idx.eval())
            self.assertAllEqual([1., 1.], split_sums.eval())
            self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_idx.eval())
            self.assertAllEqual([1., 2., 1.], total_sums.eval())
            self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
    def testSparseInput(self):
        """Checks classification stats when the input is given in sparse form.

        The dense input slot receives an empty list and the data is supplied
        through the sparse indices/values/shape arguments instead.
        """
        sparse_shape = [4, 10]
        # One [row, column] pair per entry of sparse_values, in order.
        sparse_indices = [
            [0, 0], [0, 4], [0, 9],
            [1, 0], [1, 7],
            [2, 0],
            [3, 1], [3, 4],
        ]
        sparse_values = [
            3.0, -1.0, 0.5,
            1.5, 6.0,
            -2.0,
            -0.5, 2.0,
        ]
        with self.test_session():
            stats = tensor_forest_ops.count_extremely_random_stats(
                [],
                sparse_indices,
                sparse_values,
                sparse_shape,
                self.data_spec,
                self.input_labels, [],
                self.tree,
                self.tree_thresholds,
                self.node_map,
                self.split_features,
                self.split_thresholds,
                self.epochs,
                self.current_epoch,
                num_classes=5,
                regression=False)
            # Outputs 2, 5 and 8 of the op are not checked by this test.
            (node_sums, _, split_idx, split_sums, _,
             total_idx, total_sums, _, leaves_out) = stats

            expected_node_sums = [[4., 1., 1., 1., 1.],
                                  [2., 0., 0., 1., 1.],
                                  [2., 1., 1., 0., 0.]]
            self.assertAllEqual(expected_node_sums, node_sums.eval())
            self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]],
                                split_idx.eval())
            self.assertAllEqual([1., 2., 1.], split_sums.eval())
            self.assertAllEqual([[0, 4], [0, 0], [0, 3]], total_idx.eval())
            self.assertAllEqual([1., 2., 1.], total_sums.eval())
            self.assertAllEqual([2, 2, 1, 1], leaves_out.eval())
  def testSimple(self):
    """Checks unweighted regression stats on the fixture data.

    Runs count_extremely_random_stats with an empty weights list,
    regression=True and the data spec passed as input_spec, then verifies
    every returned tensor.
    """
    with self.test_session():
      stats = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [],
          self.input_labels, [],
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          input_spec=self.data_spec,
          num_classes=2,
          regression=True)
      (node_sums, node_squares, split_idx, split_sums, split_squares,
       total_idx, total_sums, total_squares, leaves_out) = stats

      # Per-node statistics.
      self.assertAllEqual(
          [[4., 14.], [2., 9.], [2., 5.]], node_sums.eval())
      self.assertAllEqual(
          [[4., 58.], [2., 45.], [2., 13.]], node_squares.eval())
      # Candidate split statistics.
      self.assertAllEqual([[0, 0]], split_idx.eval())
      self.assertAllEqual([[1., 3.]], split_sums.eval())
      self.assertAllEqual([[1., 9.]], split_squares.eval())
      # Accumulator totals.
      self.assertAllEqual([[0]], total_idx.eval())
      self.assertAllEqual([[2., 9.]], total_sums.eval())
      self.assertAllEqual([[2., 45.]], total_squares.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
  def testSimpleWeighted(self):
    """Checks classification stats when per-example weights are supplied.

    Runs count_extremely_random_stats with weights [1.5, 2.0, 3.0, 4.0],
    regression=False and the data spec passed as input_spec, then verifies
    the sums, indices and leaves.
    """
    with self.test_session():
      weights = [1.5, 2.0, 3.0, 4.0]
      stats = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [],
          self.input_labels,
          weights,
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          input_spec=self.data_spec,
          num_classes=5,
          regression=False)
      # Outputs 2, 5 and 8 of the op are not checked by this test.
      (node_sums, _, split_idx, split_sums, _,
       total_idx, total_sums, _, leaves_out) = stats

      expected_node_sums = [[10.5, 1.5, 2., 3., 4.],
                            [3.5, 1.5, 2., 0., 0.],
                            [7., 0., 0., 3., 4.]]
      self.assertAllEqual(expected_node_sums, node_sums.eval())
      self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_idx.eval())
      self.assertAllEqual([1.5, 1.5], split_sums.eval())
      self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_idx.eval())
      self.assertAllEqual([2., 3.5, 1.5], total_sums.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
  def testThreaded(self):
    """Checks classification stats in a session with two intra-op threads.

    The session is created with intra_op_parallelism_threads=2; the
    expected outputs are asserted exactly.
    """
    session_config = config_pb2.ConfigProto(intra_op_parallelism_threads=2)
    with self.test_session(config=session_config):
      stats = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [],
          self.input_labels, [],
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          input_spec=self.data_spec,
          num_classes=5,
          regression=False)
      # Outputs 2, 5 and 8 of the op are not checked by this test.
      (node_sums, _, split_idx, split_sums, _,
       total_idx, total_sums, _, leaves_out) = stats

      expected_node_sums = [[4., 1., 1., 1., 1.],
                            [2., 1., 1., 0., 0.],
                            [2., 0., 0., 1., 1.]]
      self.assertAllEqual(expected_node_sums, node_sums.eval())
      self.assertAllEqual([[0, 0, 0], [0, 0, 1]], split_idx.eval())
      self.assertAllEqual([1., 1.], split_sums.eval())
      self.assertAllEqual([[0, 2], [0, 0], [0, 1]], total_idx.eval())
      self.assertAllEqual([1., 2., 1.], total_sums.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
Пример #13
0
  def testSparseInput(self):
    """Checks classification stats for sparse input with its own data spec.

    Builds a TensorForestDataSpec containing one sparse float feature and
    no dense features, serializes it, and passes it as input_spec. The
    dense input slot receives an empty list.
    """
    sparse_shape = [4, 10]
    # One [row, column] pair per entry of sparse_values, in order.
    sparse_indices = [
        [0, 0], [0, 4], [0, 9],
        [1, 1], [1, 7],
        [2, 0],
        [3, 0], [3, 4],
    ]
    sparse_values = [
        3.0, -1.0, 0.5,
        -1.5, 6.0,
        -2.0,
        -0.5, 2.0,
    ]

    # Describe the input: a single sparse float feature, no dense features.
    spec = data_ops.TensorForestDataSpec()
    sparse_feature = spec.sparse.add()
    sparse_feature.name = 'f1'
    sparse_feature.original_type = data_ops.DATA_FLOAT
    sparse_feature.size = -1
    spec.dense_features_size = 0
    serialized_spec = spec.SerializeToString()

    with self.test_session():
      stats = tensor_forest_ops.count_extremely_random_stats(
          [],
          sparse_indices,
          sparse_values,
          sparse_shape,
          self.input_labels, [],
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          input_spec=serialized_spec,
          num_classes=5,
          regression=False)
      # Outputs 2, 5 and 8 of the op are not checked by this test.
      (node_sums, _, split_idx, split_sums, _,
       total_idx, total_sums, _, leaves_out) = stats

      expected_node_sums = [[4., 1., 1., 1., 1.],
                            [2., 0., 0., 1., 1.],
                            [2., 1., 1., 0., 0.]]
      self.assertAllEqual(expected_node_sums, node_sums.eval())
      self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]],
                          split_idx.eval())
      self.assertAllEqual([1., 2., 1.], split_sums.eval())
      self.assertAllEqual([[0, 4], [0, 0], [0, 3]], total_idx.eval())
      self.assertAllEqual([1., 2., 1.], total_sums.eval())
      self.assertAllEqual([2, 2, 1, 1], leaves_out.eval())
  def testSparseInput(self):
    """Checks classification stats when the input is given in sparse form.

    The dense input slot receives an empty list and the data is supplied
    through the sparse indices/values/shape arguments instead.
    """
    sparse_shape = [4, 10]
    # One [row, column] pair per entry of sparse_values, in order.
    sparse_indices = [[0, 0], [0, 4], [0, 9], [1, 0], [1, 7], [2, 0],
                      [3, 1], [3, 4]]
    sparse_values = [3.0, -1.0, 0.5, 1.5, 6.0, -2.0, -0.5, 2.0]
    with self.test_session():
      stats = tensor_forest_ops.count_extremely_random_stats(
          [],
          sparse_indices,
          sparse_values,
          sparse_shape,
          self.data_spec,
          self.input_labels, [],
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          self.current_epoch,
          num_classes=5,
          regression=False)
      # Outputs 2, 5 and 8 of the op are not checked by this test.
      (node_sums, _, split_idx, split_sums, _,
       total_idx, total_sums, _, leaves_out) = stats

      expected_node_sums = [[4., 1., 1., 1., 1.],
                            [2., 0., 0., 1., 1.],
                            [2., 1., 1., 0., 0.]]
      self.assertAllEqual(expected_node_sums, node_sums.eval())
      self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]],
                          split_idx.eval())
      self.assertAllEqual([1., 2., 1.], split_sums.eval())
      self.assertAllEqual([[0, 4], [0, 0], [0, 3]], total_idx.eval())
      self.assertAllEqual([1., 2., 1.], total_sums.eval())
      self.assertAllEqual([2, 2, 1, 1], leaves_out.eval())
    def testFutureEpoch(self):
        """Checks the op outputs when current_epoch is overridden to [3].

        With this epoch the test expects all-zero node sums and empty split
        and total sums; the leaves output is still expected.
        """
        current_epoch = [3]
        with self.test_session():
            stats = tensor_forest_ops.count_extremely_random_stats(
                self.input_data, [], [], [],
                self.data_spec,
                self.input_labels, [],
                self.tree,
                self.tree_thresholds,
                self.node_map,
                self.split_features,
                self.split_thresholds,
                self.epochs,
                current_epoch,
                num_classes=5,
                regression=False)
            # Only the three sums outputs and the leaves are checked.
            (node_sums, _, _, split_sums, _, _, total_sums, _,
             leaves_out) = stats

            zero_row = [0., 0., 0., 0., 0.]
            self.assertAllEqual([zero_row, zero_row, zero_row],
                                node_sums.eval())
            self.assertAllEqual([], split_sums.eval())
            self.assertAllEqual([], total_sums.eval())
            self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
Пример #16
0
  def testBadInput(self):
    """Checks that a node_map with one entry removed triggers an op error.

    The last element of self.node_map is deleted before the op is built,
    and the test expects the size-mismatch error message below.
    """
    del self.node_map[-1]

    with self.test_session():
      expected_error = (
          'Number of nodes should be the same in '
          'tree, tree_thresholds, node_to_accumulator, and birth_epoch.')
      with self.assertRaisesOpError(expected_error):
        stats = tensor_forest_ops.count_extremely_random_stats(
            self.input_data, [], [], [],
            self.input_labels, [],
            self.tree,
            self.tree_thresholds,
            self.node_map,
            self.split_features,
            self.split_thresholds,
            self.epochs,
            self.current_epoch,
            input_spec=self.data_spec,
            num_classes=5,
            regression=False)
        (node_sums, _, _, _, _, _, _, _, _) = stats

        # The eval stays inside the assertRaisesOpError context so that an
        # error raised while running the op is caught by it.
        self.assertAllEqual([], node_sums.eval())
    def testNoAccumulators(self):
        """Checks the op outputs when every node-map entry is -1.

        [-1] * 3 is passed in the argument slot that follows tree_thresholds.
        Node sums are still expected, while the split and total outputs are
        expected to be empty (zero rows).
        """
        with self.test_session():
            (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _,
             pcw_totals_indices, pcw_totals_sums, _,
             leaves) = (tensor_forest_ops.count_extremely_random_stats(
                 self.input_data, [], [], [],
                 self.data_spec,
                 self.input_labels, [],
                 self.tree,
                 self.tree_thresholds, [-1] * 3,
                 self.split_features,
                 self.split_thresholds,
                 self.epochs,
                 self.current_epoch,
                 num_classes=5,
                 regression=False))

            self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.],
                                 [2., 0., 0., 1., 1.]], pcw_node_sums.eval())
            # assertEqual is used instead of the deprecated assertEquals
            # alias, which no longer exists in Python 3.12+.
            self.assertEqual((0, 3), pcw_splits_indices.eval().shape)
            self.assertAllEqual([], pcw_splits_sums.eval())
            self.assertEqual((0, 2), pcw_totals_indices.eval().shape)
            self.assertAllEqual([], pcw_totals_sums.eval())
            self.assertAllEqual([1, 1, 2, 2], leaves.eval())
  def testFutureEpoch(self):
    """Checks the op outputs when current_epoch is overridden to [3].

    With this epoch the test expects all-zero node sums and empty split
    and total sums; the leaves output is still expected.
    """
    current_epoch = [3]
    with self.test_session():
      stats = tensor_forest_ops.count_extremely_random_stats(
          self.input_data, [], [], [],
          self.input_labels, [],
          self.tree,
          self.tree_thresholds,
          self.node_map,
          self.split_features,
          self.split_thresholds,
          self.epochs,
          current_epoch,
          input_spec=self.data_spec,
          num_classes=5,
          regression=False)
      # Only the three sums outputs and the leaves are checked.
      (node_sums, _, _, split_sums, _, _, total_sums, _,
       leaves_out) = stats

      zero_row = [0., 0., 0., 0., 0.]
      self.assertAllEqual([zero_row, zero_row, zero_row], node_sums.eval())
      self.assertAllEqual([], split_sums.eval())
      self.assertAllEqual([], total_sums.eval())
      self.assertAllEqual([1, 1, 2, 2], leaves_out.eval())
Пример #19
0
  def training_graph(self,
                     input_data,
                     input_labels,
                     random_seed,
                     data_spec,
                     sparse_features=None,
                     input_weights=None):

    """Constructs a TF graph for training a random tree.

    The graph counts per-node and per-split statistics for the batch,
    samples new candidate splits, determines finished and stale nodes,
    picks best splits, grows the tree, reassigns accumulator slots, and
    resets the statistics of cleared/allocated accumulators.

    Args:
      input_data: A tensor or placeholder for input data. If None, an empty
        list is passed to the ops instead.
      input_labels: A tensor or placeholder for labels associated with
        input_data.
      random_seed: The random number generator seed to use for this tree.  0
        means use the current time as the seed.
      data_spec: A data_ops.TensorForestDataSpec object specifying the
        original feature/columns of the data. It is serialized with
        SerializeToString() and passed to the ops as input_spec.
      sparse_features: A tf.SparseTensor for sparse input data. Its indices,
        values and dense_shape are passed to the ops; if None, empty lists
        are passed.
      input_weights: A float tensor or placeholder holding per-input weights,
        or None if all inputs are to be weighted equally.

    Returns:
      A single op (built with control_flow_ops.group) that groups all the
      update ops created here together with those returned by
      self.finish_iteration().
    """
    # The epoch is used when counting stats, when finding finished nodes, and
    # as the start_epoch value written for updated tree nodes.
    epoch = math_ops.to_int32(get_epoch_variable())

    serialized_input_spec = data_spec.SerializeToString()

    # Absent optional inputs are passed to the ops as empty lists.
    if input_weights is None:
      input_weights = []

    if input_data is None:
      input_data = []

    sparse_indices = []
    sparse_values = []
    sparse_shape = []
    if sparse_features is not None:
      sparse_indices = sparse_features.indices
      sparse_values = sparse_features.values
      sparse_shape = sparse_features.dense_shape

    # Count extremely random stats.
    (node_sums, node_squares, splits_indices, splits_sums, splits_squares,
     totals_indices, totals_sums, totals_squares,
     input_leaves) = (tensor_forest_ops.count_extremely_random_stats(
         input_data,
         sparse_indices,
         sparse_values,
         sparse_shape,
         input_labels,
         input_weights,
         self.variables.tree,
         self.variables.tree_thresholds,
         self.variables.node_to_accumulator_map,
         self.variables.candidate_split_features,
         self.variables.candidate_split_thresholds,
         self.variables.start_epoch,
         epoch,
         input_spec=serialized_input_spec,
         num_classes=self.params.num_output_columns,
         regression=self.params.regression))
    # node_update_ops later gate the leaf score computation, while
    # splits_update_ops gate finished_nodes and best_splits.
    node_update_ops = []
    node_update_ops.append(
        state_ops.assign_add(self.variables.node_sums, node_sums))

    splits_update_ops = []
    splits_update_ops.append(
        tensor_forest_ops.scatter_add_ndim(self.variables.candidate_split_sums,
                                           splits_indices, splits_sums))
    splits_update_ops.append(
        tensor_forest_ops.scatter_add_ndim(self.variables.accumulator_sums,
                                           totals_indices, totals_sums))

    # The squares variables are only updated in regression mode.
    if self.params.regression:
      node_update_ops.append(state_ops.assign_add(self.variables.node_squares,
                                                  node_squares))
      splits_update_ops.append(
          tensor_forest_ops.scatter_add_ndim(
              self.variables.candidate_split_squares, splits_indices,
              splits_squares))
      splits_update_ops.append(
          tensor_forest_ops.scatter_add_ndim(self.variables.accumulator_squares,
                                             totals_indices, totals_squares))

    # Sample inputs.
    update_indices, feature_updates, threshold_updates = (
        tensor_forest_ops.sample_inputs(
            input_data,
            sparse_indices,
            sparse_values,
            sparse_shape,
            input_weights,
            self.variables.node_to_accumulator_map,
            input_leaves,
            self.variables.candidate_split_features,
            self.variables.candidate_split_thresholds,
            input_spec=serialized_input_spec,
            split_initializations_per_input=(
                self.params.split_initializations_per_input),
            split_sampling_random_seed=random_seed))
    update_features_op = state_ops.scatter_update(
        self.variables.candidate_split_features, update_indices,
        feature_updates)
    update_thresholds_op = state_ops.scatter_update(
        self.variables.candidate_split_thresholds, update_indices,
        threshold_updates)

    # Calculate finished nodes.
    with ops.control_dependencies(splits_update_ops):
      # Passing input_leaves to finished nodes here means that nodes that
      # have become stale won't be deallocated until an input reaches them,
      # because we're trying to avoid considering every fertile node for
      # performance reasons.
      finished, stale = tensor_forest_ops.finished_nodes(
          input_leaves,
          self.variables.node_to_accumulator_map,
          self.variables.candidate_split_sums,
          self.variables.candidate_split_squares,
          self.variables.accumulator_sums,
          self.variables.accumulator_squares,
          self.variables.start_epoch,
          epoch,
          num_split_after_samples=self.params.split_after_samples,
          min_split_samples=self.params.min_split_samples,
          dominate_method=self.params.dominate_method,
          dominate_fraction=self.params.dominate_fraction)

    # Update leaf scores.
    # TODO(thomaswc): Store the leaf scores in a TopN and only update the
    # scores of the leaves that were touched by this batch of input.
    # Column 0 of the tree variable is compared against constants.LEAF_NODE
    # to find the leaf node indices.
    children = array_ops.squeeze(
        array_ops.slice(self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaves = math_ops.to_int32(
        array_ops.squeeze(
            array_ops.where(is_leaf), squeeze_dims=[1]))
    # Non-fertile leaves are those whose node_to_accumulator_map entry is
    # negative.
    non_fertile_leaves = array_ops.boolean_mask(
        leaves, math_ops.less(array_ops.gather(
            self.variables.node_to_accumulator_map, leaves), 0))

    # TODO(gilberth): It should be possible to limit the number of non
    # fertile leaves we calculate scores for, especially since we can only take
    # at most array_ops.shape(finished)[0] of them.
    with ops.control_dependencies(node_update_ops):
      sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves)
      if self.params.regression:
        squares = array_ops.gather(self.variables.node_squares,
                                   non_fertile_leaves)
        non_fertile_leaf_scores = self._variance(sums, squares)
      else:
        non_fertile_leaf_scores = self._weighted_gini(sums)

    # Calculate best splits.
    with ops.control_dependencies(splits_update_ops):
      split_indices = tensor_forest_ops.best_splits(
          finished,
          self.variables.node_to_accumulator_map,
          self.variables.candidate_split_sums,
          self.variables.candidate_split_squares,
          self.variables.accumulator_sums,
          self.variables.accumulator_squares,
          regression=self.params.regression)

    # Grow tree.
    with ops.control_dependencies([update_features_op, update_thresholds_op,
                                   non_fertile_leaves.op]):
      (tree_update_indices, tree_children_updates, tree_threshold_updates,
       new_eot) = (tensor_forest_ops.grow_tree(
           self.variables.end_of_tree, self.variables.node_to_accumulator_map,
           finished, split_indices, self.variables.candidate_split_features,
           self.variables.candidate_split_thresholds))
      tree_update_op = state_ops.scatter_update(
          self.variables.tree, tree_update_indices, tree_children_updates)
      thresholds_update_op = state_ops.scatter_update(
          self.variables.tree_thresholds, tree_update_indices,
          tree_threshold_updates)
      # TODO(thomaswc): Only update the epoch on the new leaves.
      new_epoch_updates = epoch * array_ops.ones_like(tree_threshold_updates,
                                                      dtype=dtypes.int32)
      epoch_update_op = state_ops.scatter_update(
          self.variables.start_epoch, tree_update_indices,
          new_epoch_updates)

    # Update fertile slots.
    with ops.control_dependencies([tree_update_op]):
      (n2a_map_updates, a2n_map_updates, accumulators_cleared,
       accumulators_allocated) = (tensor_forest_ops.update_fertile_slots(
           finished,
           non_fertile_leaves,
           non_fertile_leaf_scores,
           self.variables.end_of_tree,
           self.variables.accumulator_sums,
           self.variables.node_to_accumulator_map,
           stale,
           self.variables.node_sums,
           regression=self.params.regression))

    # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has
    # used it to calculate new leaves.
    with ops.control_dependencies([n2a_map_updates.op]):
      eot_update_op = state_ops.assign(self.variables.end_of_tree, new_eot)

    # Collect every update op that the returned group must run.
    updates = []
    updates.append(eot_update_op)
    updates.append(tree_update_op)
    updates.append(thresholds_update_op)
    updates.append(epoch_update_op)

    # Element 0 of each map update is used as the scatter indices and
    # element 1 as the new values.
    updates.append(
        state_ops.scatter_update(self.variables.node_to_accumulator_map,
                                 n2a_map_updates[0], n2a_map_updates[1]))

    updates.append(
        state_ops.scatter_update(self.variables.accumulator_to_node_map,
                                 a2n_map_updates[0], a2n_map_updates[1]))

    # Cleared accumulators come first, then allocated ones; the update values
    # built below are concatenated in the same order.
    cleared_and_allocated_accumulators = array_ops.concat(
        [accumulators_cleared, accumulators_allocated], 0)

    # Calculate values to put into scatter update for candidate counts.
    # Candidate split counts are always reset back to 0 for both cleared
    # and allocated accumulators. This means some accumulators might be doubly
    # reset to 0 if they were released and not allocated, then later allocated.
    split_values = array_ops.tile(
        array_ops.expand_dims(array_ops.expand_dims(
            array_ops.zeros_like(cleared_and_allocated_accumulators,
                                 dtype=dtypes.float32), 1), 2),
        [1, self.params.num_splits_to_consider, self.params.num_output_columns])
    updates.append(state_ops.scatter_update(
        self.variables.candidate_split_sums,
        cleared_and_allocated_accumulators, split_values))
    if self.params.regression:
      updates.append(state_ops.scatter_update(
          self.variables.candidate_split_squares,
          cleared_and_allocated_accumulators, split_values))

    # Calculate values to put into scatter update for total counts.
    # Cleared accumulators get -1 in every output column, while newly
    # allocated accumulators get 0.
    total_cleared = array_ops.tile(
        array_ops.expand_dims(
            math_ops.negative(array_ops.ones_like(accumulators_cleared,
                                                  dtype=dtypes.float32)), 1),
        [1, self.params.num_output_columns])
    total_reset = array_ops.tile(
        array_ops.expand_dims(
            array_ops.zeros_like(accumulators_allocated,
                                 dtype=dtypes.float32), 1),
        [1, self.params.num_output_columns])
    accumulator_updates = array_ops.concat([total_cleared, total_reset], 0)
    updates.append(state_ops.scatter_update(
        self.variables.accumulator_sums,
        cleared_and_allocated_accumulators, accumulator_updates))
    if self.params.regression:
      updates.append(state_ops.scatter_update(
          self.variables.accumulator_squares,
          cleared_and_allocated_accumulators, accumulator_updates))

    # Calculate values to put into scatter update for candidate splits.
    # Every candidate split feature of a cleared or allocated accumulator is
    # set to -1.
    split_features_updates = array_ops.tile(
        array_ops.expand_dims(
            math_ops.negative(array_ops.ones_like(
                cleared_and_allocated_accumulators)), 1),
        [1, self.params.num_splits_to_consider])
    updates.append(state_ops.scatter_update(
        self.variables.candidate_split_features,
        cleared_and_allocated_accumulators, split_features_updates))

    updates += self.finish_iteration()

    return control_flow_ops.group(*updates)