Example #1
  def testMultidimensionalAcculumator(self):
    with self.test_session() as sess:
      accumulator = stats_accumulator_ops.StatsAccumulator(
          stamp_token=0,
          gradient_shape=tensor_shape.scalar(),
          hessian_shape=tensor_shape.scalar())
      with ops.control_dependencies([accumulator._create_op]):
        op1 = accumulator.add(
            stamp_token=0,
            partition_ids=[1, 2, 1],
            feature_ids=[[2, 2], [3, 0], [2, 2]],
            gradients=[0.1, 0.3, 0.8],
            hessians=[0.2, 0.4, -9])
        op2 = accumulator.add(0, [2, 1], [[3, 1], [2, 2]], [0.1, 1], [0.2, -1])

      with ops.control_dependencies([op1, op2]):
        num_updates, partition, bucket_ids, grads, hessians = accumulator.flush(
            stamp_token=0, next_stamp_token=1)
        num_updates, partition, bucket_ids, grads, hessians = sess.run(
            [num_updates, partition, bucket_ids, grads, hessians])

      result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians)
      self.assertEqual(num_updates, 2)
      self.assertEqual(len(result), 3)
      # Key is partition, bucket, dimension.
      self.assertAllClose(result[(1, 2, 2)], [1.9, -9.8])
      self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4])
      self.assertAllClose(result[(2, 3, 1)], [0.1, 0.2])
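
These tests all use a module-level helper, _AccumulatorResultToDict, which is not reproduced on this page. A minimal sketch consistent with how the results are indexed in the assertions (keys formed from the partition id followed by the feature/bucket ids, values holding the accumulated gradient and hessian) could look like the following; the helper actually used in the test file may differ in detail.

import numpy as np


def _AccumulatorResultToDict(partition, feature, grads, hessians):
  # Key each row by (partition_id, *feature_ids) so results can be compared
  # regardless of the order in which the accumulator returns them.
  # This is a sketch inferred from the assertions above, not the original helper.
  return {
      tuple([partition[i]] + list(np.atleast_1d(feature[i]))):
          (grads[i], hessians[i])
      for i in range(len(partition))
  }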
Example #2
  def testDropStaleUpdate(self):
    with self.test_session() as sess:
      accumulator = stats_accumulator_ops.StatsAccumulator(
          stamp_token=0,
          gradient_shape=tensor_shape.scalar(),
          hessian_shape=tensor_shape.scalar())
      with ops.control_dependencies([accumulator._create_op]):
        op1 = accumulator.add(
            stamp_token=0,
            partition_ids=[1, 2],
            feature_ids=[[2, 0], [3, 0]],
            gradients=[0.1, 0.3],
            hessians=[0.2, 0.4])
        op2 = accumulator.add(
            stamp_token=-1,
            partition_ids=[1],
            feature_ids=[[2, 0]],
            gradients=[0.1],
            hessians=[0.2])

      with ops.control_dependencies([op1, op2]):
        num_updates, partition, feature, grads, hessians = accumulator.flush(
            stamp_token=0, next_stamp_token=1)
        num_updates, partition, feature, grads, hessians = sess.run(
            [num_updates, partition, feature, grads, hessians])

      result = _AccumulatorResultToDict(partition, feature, grads, hessians)
      self.assertEqual(num_updates, 1)
      self.assertEqual(len(result), 2)
      self.assertAllClose(result[(1, 2, 0)], [0.1, 0.2])
      self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4])
    def testDeserialize(self):
        with self.test_session() as sess:
            accumulator = stats_accumulator_ops.StatsAccumulator(
                stamp_token=0,
                gradient_shape=tensor_shape.scalar(),
                hessian_shape=tensor_shape.scalar())
            with ops.control_dependencies([accumulator._create_op]):
                # These will be deleted due to the deserialize call.
                op1 = accumulator.add(stamp_token=0,
                                      partition_ids=[1, 2],
                                      feature_ids=[2, 3],
                                      gradients=[0.1, 0.3],
                                      hessians=[0.2, 0.4])

            with ops.control_dependencies([op1]):
                deserialize = (accumulator.deserialize(stamp_token=2,
                                                       num_updates=3,
                                                       partition_ids=[3, 4],
                                                       feature_ids=[5, 6],
                                                       gradients=[0.4, 0.5],
                                                       hessians=[0.6, 0.7]))
            with ops.control_dependencies([deserialize]):
                num_updates, partition, feature, grads, hessians = accumulator.flush(
                    stamp_token=2, next_stamp_token=3)
                num_updates, partition, feature, grads, hessians = sess.run(
                    [num_updates, partition, feature, grads, hessians])

            result = _AccumulatorResultToDict(partition, feature, grads,
                                              hessians)
            self.assertEqual(num_updates, 3)
            self.assertEqual(len(result), 2)
            self.assertAllClose(result[(3, 5)], [0.4, 0.6])
            self.assertAllClose(result[(4, 6)], [0.5, 0.7])
    def testMakeSummary(self):
        with self.test_session() as sess:
            accumulator = stats_accumulator_ops.StatsAccumulator(
                stamp_token=0,
                gradient_shape=tensor_shape.TensorShape([2]),
                hessian_shape=tensor_shape.TensorShape([2, 2]))
            partition, feature, grads, hessians = accumulator._make_summary(
                partition_ids=[1, 2, 1],
                feature_ids=[2, 3, 2],
                # Two values for gradients,
                gradients=[[0.1, 0.1], [0.2, 0.2], [0.10, 0.11]],
                # A 2x2 matrix for each hessian.
                hessians=[[[0.01, 0.02], [0.03, 0.04]],
                          [[0.05, 0.06], [0.07, 0.08]],
                          [[0.011, 0.022], [0.033, 0.044]]])
            partition, feature, grads, hessians = sess.run(
                [partition, feature, grads, hessians])

            result = _AccumulatorResultToDict(partition, feature, grads,
                                              hessians)
            self.assertEqual(len(result), 2)
            self.assertAllClose(result[(1, 2)][0], [0.20, 0.21])
            self.assertAllClose(result[(1, 2)][1],
                                [[0.021, 0.042], [0.063, 0.084]])
            self.assertAllClose(result[(2, 3)][0], [0.2, 0.2])
            self.assertAllClose(result[(2, 3)][1],
                                [[0.05, 0.06], [0.07, 0.08]])
Example #5
  def testSimpleAcculumator(self):
    with self.cached_session() as sess:
      accumulator = stats_accumulator_ops.StatsAccumulator(
          stamp_token=0,
          gradient_shape=tensor_shape.TensorShape([]),
          hessian_shape=tensor_shape.TensorShape([]))
      with ops.control_dependencies([accumulator.initializer]):
        op1 = accumulator.add(
            stamp_token=0,
            partition_ids=[1, 2],
            feature_ids=[[2, 0], [3, 0]],
            gradients=[0.1, 0.3],
            hessians=[0.2, 0.4])
        op2 = accumulator.add(0, [1], [[2, 0]], [0.1], [0.2])

      with ops.control_dependencies([op1, op2]):
        num_updates, partition, bucket_ids, grads, hessians = accumulator.flush(
            stamp_token=0, next_stamp_token=1)
        num_updates, partition, bucket_ids, grads, hessians = sess.run(
            [num_updates, partition, bucket_ids, grads, hessians])

      result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians)
      self.assertEqual(num_updates, 2)
      self.assertEqual(len(result), 2)
      # Key is partition, bucket, dimension
      self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4])
      self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4])
Example #6
  def __init__(self,
               l1_regularization,
               l2_regularization,
               tree_complexity_regularization,
               min_node_weight,
               feature_column_group_id,
               epsilon,
               num_quantiles,
               gradient_shape,
               hessian_shape,
               multiclass_strategy,
               init_stamp_token=0,
               loss_uses_sum_reduction=False,
               name=None):
    """Initialize the internal state for this split handler.

    Args:
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      epsilon: A float, the error bound for quantile computation.
      num_quantiles: An int, the number of buckets to create from the histogram.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A tensor containing a scalar, the initial stamp of the
          stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
          SUM or MEAN reduction was used for the loss.
      name: An optional handler name.
    """
    super(InequalitySplitHandler, self).__init__(
        name=name,
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        tree_complexity_regularization=tree_complexity_regularization,
        min_node_weight=min_node_weight,
        feature_column_group_id=feature_column_group_id,
        gradient_shape=gradient_shape,
        hessian_shape=hessian_shape,
        multiclass_strategy=multiclass_strategy,
        loss_uses_sum_reduction=loss_uses_sum_reduction)
    self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
        init_stamp_token,
        gradient_shape,
        hessian_shape,
        name="StatsAccumulator/{}".format(self._name))
    # Allocate both stats accumulator and quantile accumulator on the same
    # device so that we can build splits with fewer RPCs.
    with ops.colocate_with(self._stats_accumulator.resource()):
      self._quantile_accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token,
          epsilon=epsilon,
          num_quantiles=num_quantiles,
          name="QuantileAccumulator/{}".format(self._name))
Example #7
    def __init__(
            self,
            sparse_int_column,
            l1_regularization,
            l2_regularization,
            tree_complexity_regularization,
            min_node_weight,
            feature_column_group_id,
            gradient_shape,
            hessian_shape,
            multiclass_strategy,
            init_stamp_token=0,
            loss_uses_sum_reduction=False,
            weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
            name=None):
        """Initialize the internal state for this split handler.

    Args:
      sparse_int_column: A `SparseTensor` column with int64 values associated
        with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A tensor containing a scalar, the initial stamp of the
          stamped objects.
      loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
          SUM or MEAN reduction was used for the loss.
      weak_learner_type: Specifies the type of weak learner to use.
      name: An optional handler name.
    """
        super(EqualitySplitHandler, self).__init__(
            l1_regularization=l1_regularization,
            l2_regularization=l2_regularization,
            tree_complexity_regularization=tree_complexity_regularization,
            min_node_weight=min_node_weight,
            feature_column_group_id=feature_column_group_id,
            gradient_shape=gradient_shape,
            hessian_shape=hessian_shape,
            multiclass_strategy=multiclass_strategy,
            loss_uses_sum_reduction=loss_uses_sum_reduction,
            name=name)
        self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
            init_stamp_token,
            gradient_shape,
            hessian_shape,
            name="StatsAccumulator/{}".format(self._name))
        self._sparse_int_column = sparse_int_column
        self._weak_learner_type = weak_learner_type
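
A similar, hedged usage sketch for this categorical handler, with a toy SparseTensor of int64 values standing in for a real sparse feature column (module paths again assumed from the tf.contrib.boosted_trees layout):

from tensorflow.contrib.boosted_trees.lib.learner.batch import categorical_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.python.framework import constant_op, dtypes, sparse_tensor, tensor_shape

# Toy categorical column: three examples, one int64 feature value each.
sparse_column = sparse_tensor.SparseTensor(
    indices=[[0, 0], [1, 0], [2, 0]],
    values=constant_op.constant([1, 2, 1], dtype=dtypes.int64),
    dense_shape=[3, 1])

handler = categorical_split_handler.EqualitySplitHandler(
    sparse_int_column=sparse_column,
    l1_regularization=0.0,
    l2_regularization=1.0,
    tree_complexity_regularization=0.0,
    min_node_weight=0.0,
    feature_column_group_id=0,
    gradient_shape=tensor_shape.scalar(),
    hessian_shape=tensor_shape.scalar(),
    multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
    init_stamp_token=0,
    name="sparse_int_0")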
    def testSerialize(self):
        with self.test_session() as sess:
            accumulator = stats_accumulator_ops.StatsAccumulator(
                stamp_token=0,
                gradient_shape=tensor_shape.TensorShape([2]),
                hessian_shape=tensor_shape.TensorShape([2, 2]))
            with ops.control_dependencies([accumulator._create_op]):
                op1 = accumulator.add(
                    stamp_token=0,
                    partition_ids=[1, 2],
                    feature_ids=[2, 3],
                    # Two values for gradients,
                    gradients=[[0.1, 0.1], [0.2, 0.2]],
                    # A 2x2 matrix for each hessian.
                    hessians=[[[0.01, 0.02], [0.03, 0.04]],
                              [[0.05, 0.06], [0.07, 0.08]]])

            with ops.control_dependencies([op1]):
                (stamp_token, num_updates_1, partition_1, feature_1, grads_1,
                 hessians_1) = accumulator.serialize()
            # Make sure that the accumulator hasn't changed during serialization.
            with ops.control_dependencies([stamp_token]):
                num_updates_2, partition_2, feature_2, grads_2, hessians_2 = (
                    accumulator.flush(stamp_token=0, next_stamp_token=1))
                (stamp_token, num_updates_1, partition_1, feature_1, grads_1,
                 hessians_1, num_updates_2, partition_2, feature_2, grads_2,
                 hessians_2) = sess.run([
                     stamp_token, num_updates_1, partition_1, feature_1,
                     grads_1, hessians_1, num_updates_2, partition_2,
                     feature_2, grads_2, hessians_2
                 ])

            result_1 = _AccumulatorResultToDict(partition_1, feature_1,
                                                grads_1, hessians_1)
            result_2 = _AccumulatorResultToDict(partition_2, feature_2,
                                                grads_2, hessians_2)

            self.assertEqual(num_updates_1, 1)
            self.assertEqual(num_updates_2, 1)
            self.assertEqual(len(result_1), 2)
            self.assertAllClose(result_1[(1, 2)][0], [0.1, 0.1])
            self.assertAllClose(result_1[(1, 2)][1],
                                [[0.01, 0.02], [0.03, 0.04]])
            self.assertAllClose(result_1[(2, 3)][0], [0.2, 0.2])
            self.assertAllClose(result_1[(2, 3)][1],
                                [[0.05, 0.06], [0.07, 0.08]])

            self.assertAllEqual(result_1[1, 2][0], result_2[1, 2][0])
            self.assertAllEqual(result_1[1, 2][1], result_2[1, 2][1])
            self.assertAllEqual(result_1[2, 3][0], result_2[2, 3][0])
            self.assertAllEqual(result_1[2, 3][1], result_2[2, 3][1])
    def __init__(self,
                 sparse_int_column,
                 l1_regularization,
                 l2_regularization,
                 tree_complexity_regularization,
                 min_node_weight,
                 feature_column_group_id,
                 gradient_shape,
                 hessian_shape,
                 multiclass_strategy,
                 init_stamp_token=0,
                 name=None):
        """Initialize the internal state for this split handler.

    Args:
      sparse_int_column: A `SparseTensor` column with int64 values associated
        with this handler.
      l1_regularization: L1 regularization applied for this split handler.
      l2_regularization: L2 regularization applied for this split handler.
      tree_complexity_regularization: Tree complexity regularization applied
          for this split handler.
      min_node_weight: Minimum sum of weights of examples in each partition to
          be considered for splitting.
      feature_column_group_id: Feature column group index.
      gradient_shape: A TensorShape, containing shape of gradients.
      hessian_shape: A TensorShape, containing shape of hessians.
      multiclass_strategy: Strategy describing how to treat multiclass problems.
      init_stamp_token: A tensor containing a scalar, the initial stamp of the
          stamped objects.
      name: An optional handler name.
    """
        super(EqualitySplitHandler, self).__init__(
            l1_regularization=l1_regularization,
            l2_regularization=l2_regularization,
            tree_complexity_regularization=tree_complexity_regularization,
            min_node_weight=min_node_weight,
            feature_column_group_id=feature_column_group_id,
            gradient_shape=gradient_shape,
            hessian_shape=hessian_shape,
            multiclass_strategy=multiclass_strategy,
            name=name)
        self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
            init_stamp_token,
            gradient_shape,
            hessian_shape,
            name="StatsAccumulator/{}".format(self._name))
        self._sparse_int_column = sparse_int_column
Example #10
  def testMakeSummary(self):
    with self.test_session() as sess:
      accumulator = stats_accumulator_ops.StatsAccumulator(
          stamp_token=0,
          gradient_shape=tensor_shape.scalar(),
          hessian_shape=tensor_shape.scalar())
      partition, feature, grads, hessians = accumulator._make_summary(
          partition_ids=[1, 2, 1],
          feature_ids=[[2, 0], [3, 1], [2, 0]],
          gradients=[0.1, 0.3, 0.1],
          hessians=[0.2, 0.4, 0.2])
      partition, feature, grads, hessians = sess.run(
          [partition, feature, grads, hessians])
      result = _AccumulatorResultToDict(partition, feature, grads, hessians)
      self.assertEqual(len(result), 2)
      self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4])
      self.assertAllClose(result[(2, 3, 1)], [0.3, 0.4])
    def testDeserialize(self):
        with self.test_session() as sess:
            accumulator = stats_accumulator_ops.StatsAccumulator(
                stamp_token=0,
                gradient_shape=tensor_shape.TensorShape([2]),
                hessian_shape=tensor_shape.TensorShape([2, 2]))
            with ops.control_dependencies([accumulator._create_op]):
                # These will be deleted due to the deserialize call.
                op1 = accumulator.add(
                    stamp_token=0,
                    partition_ids=[1, 2],
                    feature_ids=[2, 3],
                    # Two values for gradients,
                    gradients=[[0.1, 0.1], [0.2, 0.2]],
                    # A 2x2 matrix for each hessian.
                    hessians=[[[0.01, 0.02], [0.03, 0.04]],
                              [[0.05, 0.06], [0.07, 0.08]]])

            with ops.control_dependencies([op1]):
                deserialize = accumulator.deserialize(
                    stamp_token=2,
                    num_updates=3,
                    partition_ids=[3, 4],
                    feature_ids=[4, 5],
                    # Two values for gradients,
                    gradients=[[0.3, 0.3], [0.5, 0.5]],
                    # A 2x2 matrix for each hessian.
                    hessians=[[[0.03, 0.04], [0.05, 0.06]],
                              [[0.07, 0.08], [0.09, 0.10]]])
            with ops.control_dependencies([deserialize]):
                num_updates, partition, feature, grads, hessians = accumulator.flush(
                    stamp_token=2, next_stamp_token=3)
                num_updates, partition, feature, grads, hessians = sess.run(
                    [num_updates, partition, feature, grads, hessians])

            result = _AccumulatorResultToDict(partition, feature, grads,
                                              hessians)
            self.assertEqual(num_updates, 3)
            self.assertEqual(len(result), 2)
            self.assertAllClose(result[(3, 4)][0], [0.3, 0.3])
            self.assertAllClose(result[(3, 4)][1],
                                [[0.03, 0.04], [0.05, 0.06]])
            self.assertAllClose(result[(4, 5)][0], [0.5, 0.5])
            self.assertAllClose(result[(4, 5)][1],
                                [[0.07, 0.08], [0.09, 0.10]])
Example #12
  def testSerialize(self):
    with self.test_session() as sess:
      accumulator = stats_accumulator_ops.StatsAccumulator(
          stamp_token=0,
          gradient_shape=tensor_shape.scalar(),
          hessian_shape=tensor_shape.scalar())
      with ops.control_dependencies([accumulator._create_op]):
        op1 = accumulator.add(
            stamp_token=0,
            partition_ids=[1, 2],
            feature_ids=[[2, 0], [3, 0]],
            gradients=[0.1, 0.3],
            hessians=[0.2, 0.4])

      with ops.control_dependencies([op1]):
        (stamp_token, num_updates, partition_1, feature_1, grads_1,
         hessians_1) = accumulator.serialize()
      # Make sure that the accumulator hasn't changed during serialization.
      with ops.control_dependencies([stamp_token]):
        num_updates_2, partition_2, feature_2, grads_2, hessians_2 = (
            accumulator.flush(stamp_token=0, next_stamp_token=1))
        (stamp_token, num_updates, partition_1, feature_1, grads_1, hessians_1,
         num_updates_2, partition_2, feature_2, grads_2, hessians_2) = sess.run(
             [
                 stamp_token, num_updates, partition_1, feature_1, grads_1,
                 hessians_1, num_updates_2, partition_2, feature_2, grads_2,
                 hessians_2
             ])

      result_1 = _AccumulatorResultToDict(partition_1, feature_1, grads_1,
                                          hessians_1)
      result_2 = _AccumulatorResultToDict(partition_2, feature_2, grads_2,
                                          hessians_2)
      self.assertEqual(num_updates, 1)
      self.assertEqual(num_updates_2, 1)
      self.assertEqual(len(result_1), 2)
      self.assertAllClose(result_1[(1, 2, 0)], [0.1, 0.2])
      self.assertAllClose(result_1[(2, 3, 0)], [0.3, 0.4])
      self.assertAllEqual(result_1, result_2)
      self.assertEqual(0, stamp_token)
    def testDropStaleUpdate(self):
        with self.test_session() as sess:
            accumulator = stats_accumulator_ops.StatsAccumulator(
                stamp_token=0,
                gradient_shape=tensor_shape.TensorShape([2]),
                hessian_shape=tensor_shape.TensorShape([2, 2]))
            with ops.control_dependencies([accumulator._create_op]):
                op1 = accumulator.add(
                    stamp_token=0,
                    partition_ids=[1, 2],
                    feature_ids=[2, 3],
                    # Two values for gradients,
                    gradients=[[0.1, 0.1], [0.2, 0.2]],
                    # A 2x2 matrix for each hessian.
                    hessians=[[[0.01, 0.02], [0.03, 0.04]],
                              [[0.05, 0.06], [0.07, 0.08]]])
                op2 = accumulator.add(stamp_token=-1,
                                      partition_ids=[1],
                                      feature_ids=[2],
                                      gradients=[[0.10, 0.11]],
                                      hessians=[[[0.011, 0.022],
                                                 [0.033, 0.044]]])

            with ops.control_dependencies([op1, op2]):
                num_updates, partition, feature, grads, hessians = accumulator.flush(
                    stamp_token=0, next_stamp_token=1)
                num_updates, partition, feature, grads, hessians = sess.run(
                    [num_updates, partition, feature, grads, hessians])

            result = _AccumulatorResultToDict(partition, feature, grads,
                                              hessians)
            self.assertEqual(num_updates, 1)
            self.assertEqual(len(result), 2)
            self.assertAllClose(result[(1, 2)][0], [0.1, 0.1])
            self.assertAllClose(result[(1, 2)][1],
                                [[0.01, 0.02], [0.03, 0.04]])
            self.assertAllClose(result[(2, 3)][0], [0.2, 0.2])
            self.assertAllClose(result[(2, 3)][1],
                                [[0.05, 0.06], [0.07, 0.08]])
Example #14
    def train(self, loss, predictions_dict, labels):
        """Grows a new tree and adds it to the ensemble.

    Args:
      loss: A scalar tensor representing average loss of examples.
      predictions_dict: Dictionary of Rank 2 `Tensor` representing information
          about predictions per example.
      labels: Rank 2 `Tensor` representing labels per example.

    Returns:
      An op that adds a new tree to the ensemble.

    Raises:
      ValueError: if inputs are not valid.
    """
        # Get the worker device from input dependencies.
        input_deps = (self._dense_floats + self._sparse_float_indices +
                      self._sparse_int_indices)
        worker_device = input_deps[0].device

        # Get tensors relevant for training and form the loss.
        predictions = predictions_dict[PREDICTIONS]
        partition_ids = predictions_dict[PARTITION_IDS]
        ensemble_stamp = predictions_dict[ENSEMBLE_STAMP]
        gradients = gradients_impl.gradients(loss,
                                             predictions,
                                             name="Gradients",
                                             colocate_gradients_with_ops=False,
                                             gate_gradients=0,
                                             aggregation_method=None)[0]
        strategy = self._learner_config.multi_class_strategy

        class_id = -1
        # Handle different multiclass strategies.
        if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS:
            # We build one vs rest trees.
            gradient_shape = tensor_shape.scalar()
            hessian_shape = tensor_shape.scalar()

            if self._logits_dimension == 1:
                # We have only 1 score, gradients is of shape [batch, 1].
                hessians = gradients_impl.gradients(
                    gradients,
                    predictions,
                    name="Hessian",
                    colocate_gradients_with_ops=False,
                    gate_gradients=0,
                    aggregation_method=None)[0]

                squeezed_gradients = array_ops.squeeze(gradients, axis=[1])
                squeezed_hessians = array_ops.squeeze(hessians, axis=[1])
            else:
                hessian_list = self._diagonal_hessian(gradients, predictions)
                # Assemble hessian list into a tensor.
                hessians = array_ops.stack(hessian_list, axis=1)

                # Choose the class for which the tree is built (one vs rest).
                class_id = math_ops.to_int32(
                    predictions_dict[NUM_TREES_ATTEMPTED] %
                    self._logits_dimension)

                # Use class id tensor to get the column with that index from gradients
                # and hessians.
                squeezed_gradients = array_ops.squeeze(
                    _get_column_by_index(gradients, class_id))
                squeezed_hessians = array_ops.squeeze(
                    _get_column_by_index(hessians, class_id))
        else:
            # Other multiclass strategies.
            gradient_shape = tensor_shape.TensorShape([self._logits_dimension])

            if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN:
                hessian_shape = tensor_shape.TensorShape(
                    ([self._logits_dimension, self._logits_dimension]))
                hessian_list = self._full_hessian(gradients, predictions)
            else:
                # Diagonal hessian strategy.
                hessian_shape = tensor_shape.TensorShape(
                    ([self._logits_dimension]))
                hessian_list = self._diagonal_hessian(gradients, predictions)

            squeezed_gradients = gradients
            hessians = array_ops.stack(hessian_list, axis=1)
            squeezed_hessians = hessians

        # Get the weights for each example for quantile calculation.
        weights = self._get_weights(hessian_shape, squeezed_hessians)

        regularization_config = self._learner_config.regularization
        min_node_weight = self._learner_config.constraints.min_node_weight
        # Create all handlers ensuring resources are evenly allocated across PS.
        fc_name_idx = 0
        handlers = []
        init_stamp_token = constant_op.constant(0, dtype=dtypes.int64)
        with ops.device(self._get_replica_device_setter(worker_device)):
            # Create handlers for dense float columns
            for dense_float_column_idx in range(len(self._dense_floats)):
                fc_name = self._fc_names[fc_name_idx]
                handlers.append(
                    ordinal_split_handler.DenseSplitHandler(
                        l1_regularization=regularization_config.l1,
                        l2_regularization=regularization_config.l2,
                        tree_complexity_regularization=(
                            regularization_config.tree_complexity),
                        min_node_weight=min_node_weight,
                        feature_column_group_id=dense_float_column_idx,
                        epsilon=0.01,
                        num_quantiles=100,
                        dense_float_column=self._dense_floats[
                            dense_float_column_idx],
                        name=fc_name,
                        gradient_shape=gradient_shape,
                        hessian_shape=hessian_shape,
                        multiclass_strategy=strategy,
                        init_stamp_token=init_stamp_token))
                fc_name_idx += 1

            # Create handlers for sparse float columns.
            for sparse_float_column_idx in range(
                    len(self._sparse_float_indices)):
                fc_name = self._fc_names[fc_name_idx]
                handlers.append(
                    ordinal_split_handler.SparseSplitHandler(
                        l1_regularization=regularization_config.l1,
                        l2_regularization=regularization_config.l2,
                        tree_complexity_regularization=(
                            regularization_config.tree_complexity),
                        min_node_weight=min_node_weight,
                        feature_column_group_id=sparse_float_column_idx,
                        epsilon=0.01,
                        num_quantiles=100,
                        sparse_float_column=sparse_tensor.SparseTensor(
                            self._sparse_float_indices[sparse_float_column_idx],
                            self._sparse_float_values[sparse_float_column_idx],
                            self._sparse_float_shapes[sparse_float_column_idx]),
                        name=fc_name,
                        gradient_shape=gradient_shape,
                        hessian_shape=hessian_shape,
                        multiclass_strategy=strategy,
                        init_stamp_token=init_stamp_token))
                fc_name_idx += 1

            # Create handlers for sparse int columns.
            for sparse_int_column_idx in range(len(self._sparse_int_indices)):
                fc_name = self._fc_names[fc_name_idx]
                handlers.append(
                    categorical_split_handler.EqualitySplitHandler(
                        l1_regularization=regularization_config.l1,
                        l2_regularization=regularization_config.l2,
                        tree_complexity_regularization=(
                            regularization_config.tree_complexity),
                        min_node_weight=min_node_weight,
                        feature_column_group_id=sparse_int_column_idx,
                        sparse_int_column=sparse_tensor.SparseTensor(
                            self._sparse_int_indices[sparse_int_column_idx],
                            self._sparse_int_values[sparse_int_column_idx],
                            self._sparse_int_shapes[sparse_int_column_idx]),
                        name=fc_name,
                        gradient_shape=gradient_shape,
                        hessian_shape=hessian_shape,
                        multiclass_strategy=strategy,
                        init_stamp_token=init_stamp_token))
                fc_name_idx += 1

            # Create steps accumulator.
            steps_accumulator = stats_accumulator_ops.StatsAccumulator(
                stamp_token=0,
                gradient_shape=tensor_shape.scalar(),
                hessian_shape=tensor_shape.scalar(),
                name="StepsAccumulator")

            # Create bias stats accumulator.
            bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator(
                stamp_token=0,
                gradient_shape=gradient_shape,
                hessian_shape=hessian_shape,
                name="BiasAccumulator")

            # Create ensemble stats variables.
            num_layer_examples = variables.Variable(
                initial_value=array_ops.zeros([], dtypes.int64),
                name="num_layer_examples",
                trainable=False)
            num_layer_steps = variables.Variable(
                initial_value=array_ops.zeros([], dtypes.int64),
                name="num_layer_steps",
                trainable=False)
            num_layers = variables.Variable(
                initial_value=array_ops.zeros([], dtypes.int64),
                name="num_layers",
                trainable=False)
            active_tree = variables.Variable(
                initial_value=array_ops.zeros([], dtypes.int64),
                name="active_tree",
                trainable=False)
            active_layer = variables.Variable(
                initial_value=array_ops.zeros([], dtypes.int64),
                name="active_layer",
                trainable=False)

        # Create ensemble stats summaries.
        summary.scalar("layer_stats/num_examples", num_layer_examples)
        summary.scalar("layer_stats/num_steps", num_layer_steps)
        summary.scalar("ensemble_stats/active_tree", active_tree)
        summary.scalar("ensemble_stats/active_layer", active_layer)

        # Update bias stats.
        stats_update_ops = []
        continue_centering = variables.Variable(
            initial_value=self._center_bias,
            name="continue_centering",
            trainable=False)
        stats_update_ops.append(
            control_flow_ops.cond(
                continue_centering,
                self._make_update_bias_stats_fn(ensemble_stamp, predictions,
                                                gradients,
                                                bias_stats_accumulator),
                control_flow_ops.no_op))

        # Update handler stats.
        handler_reads = {}
        for handler in handlers:
            handler_reads[handler] = handler.scheduled_reads()

        handler_results = batch_ops_utils.run_handler_scheduled_ops(
            handler_reads, ensemble_stamp, worker_device)
        per_handler_updates = {}
        # Two values per handler: whether the handler is active for the
        # current layer, and whether it will be active for the next layer.
        subsampling_type = self._learner_config.WhichOneof("feature_fraction")
        if subsampling_type == "feature_fraction_per_level":
            seed = predictions_dict[NUM_LAYERS_ATTEMPTED]
            active_handlers_current_layer = stateless.stateless_random_uniform(
                shape=[len(handlers)], seed=[seed, 1])
            active_handlers_next_layer = stateless.stateless_random_uniform(
                shape=[len(handlers)], seed=[seed + 1, 1])
            active_handlers = array_ops.stack(
                [active_handlers_current_layer, active_handlers_next_layer],
                axis=1)
            active_handlers = (active_handlers <
                               self._learner_config.feature_fraction_per_level)
        elif subsampling_type == "feature_fraction_per_tree":
            seed = predictions_dict[NUM_TREES_ATTEMPTED]
            active_handlers_current_layer = stateless.stateless_random_uniform(
                shape=[len(handlers)], seed=[seed, 2])
            active_handlers_current_layer = (
                active_handlers_current_layer <
                self._learner_config.feature_fraction_per_tree)
            active_handlers = array_ops.stack(
                [
                    active_handlers_current_layer,
                    array_ops.ones([len(handlers)], dtype=dtypes.bool)
                ],
                axis=1)
        else:
            active_handlers = array_ops.ones([len(handlers), 2],
                                             dtype=dtypes.bool)

        # Prepare empty gradients and hessians when handlers are not ready.
        empty_hess_shape = [1] + hessian_shape.as_list()
        empty_grad_shape = [1] + gradient_shape.as_list()

        empty_gradients = constant_op.constant([],
                                               dtype=dtypes.float32,
                                               shape=empty_grad_shape)
        empty_hessians = constant_op.constant([],
                                              dtype=dtypes.float32,
                                              shape=empty_hess_shape)

        for handler_idx in range(len(handlers)):
            handler = handlers[handler_idx]
            is_active = active_handlers[handler_idx]
            updates, scheduled_updates = handler.update_stats(
                ensemble_stamp, partition_ids, squeezed_gradients,
                squeezed_hessians, empty_gradients, empty_hessians, weights,
                is_active, handler_results[handler])
            stats_update_ops.append(updates)
            per_handler_updates[handler] = scheduled_updates

        update_results = batch_ops_utils.run_handler_scheduled_ops(
            per_handler_updates, ensemble_stamp, worker_device)
        for update in update_results.values():
            stats_update_ops += update
        # Accumulate a step after updating stats.
        batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32)
        with ops.control_dependencies(stats_update_ops):
            add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]],
                                                [batch_size], [1.0])

        # Determine learning rate.
        learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof(
            "tuner")
        if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout":
            tuner = getattr(self._learner_config.learning_rate_tuner,
                            learning_rate_tuner)
            learning_rate = tuner.learning_rate
        else:
            # TODO (nponomareva, soroush) do the line search. id:498 gh:499
            raise ValueError("Line search learning rate is not yet supported.")

        # After adding the step, decide if further processing is needed.
        ensemble_update_ops = [add_step_op]
        with ops.control_dependencies([add_step_op]):
            if self._is_chief:
                dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED]

                # Get accumulated steps and examples for the current layer.
                _, _, _, _, acc_examples, acc_steps = steps_accumulator.serialize(
                )
                acc_examples = math_ops.cast(acc_examples[0], dtypes.int64)
                acc_steps = math_ops.cast(acc_steps[0], dtypes.int64)
                ensemble_update_ops.append(
                    num_layer_examples.assign(acc_examples))
                ensemble_update_ops.append(num_layer_steps.assign(acc_steps))
                # Determine whether we need to update tree ensemble.
                examples_per_layer = self._examples_per_layer
                if callable(examples_per_layer):
                    examples_per_layer = examples_per_layer(active_layer)
                ensemble_update_ops.append(
                    control_flow_ops.cond(
                        acc_examples >= examples_per_layer,
                        self._make_update_ensemble_fn(
                            ensemble_stamp, steps_accumulator,
                            bias_stats_accumulator, continue_centering,
                            learning_rate, handlers, num_layers, active_tree,
                            active_layer, dropout_seed, class_id),
                        control_flow_ops.no_op))

        # Calculate the loss to be reported.
        # Note: the loss is calculated from predictions that include dropout, so
        # the value may fluctuate across steps when the dropout ratio is high.
        # Refer to eval_loss instead when assessing convergence.
        return control_flow_ops.group(*ensemble_update_ops)
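
To make the gradient and hessian wiring at the top of train() concrete, here is a standalone sketch of the single-logit TREE_PER_CLASS branch: the per-example gradient is d(loss)/d(predictions), and the hessian is obtained by differentiating that gradient with respect to the predictions again. The sigmoid cross-entropy loss is an assumed stand-in for whatever loss the estimator actually uses.

import tensorflow as tf

# Placeholders standing in for the model's per-example logits and labels.
predictions = tf.placeholder(tf.float32, shape=[None, 1])
labels = tf.placeholder(tf.float32, shape=[None, 1])
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=predictions))

# First derivative of the loss w.r.t. the predictions: shape [batch, 1].
gradients = tf.gradients(loss, predictions, name="Gradients")[0]
# Second derivative, obtained by differentiating the gradient once more.
hessians = tf.gradients(gradients, predictions, name="Hessian")[0]

# Squeeze to rank 1, matching the scalar gradient/hessian shapes that the
# split handlers and StatsAccumulator in the examples above expect.
squeezed_gradients = tf.squeeze(gradients, axis=[1])
squeezed_hessians = tf.squeeze(hessians, axis=[1])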