コード例 #1
0
  def testCachedPredictionOnEmptyEnsemble(self):
    """Checks that predicting on a dummy (empty) ensemble succeeds.

    The ensemble is built from an empty serialized proto. The test expects the
    cached tree/node ids to come back unchanged and the logits updates to be
    zero for both examples.
    """
    with self.cached_session() as sess:
      # Dummy ensemble: built from an empty serialized proto.
      dummy_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto='')
      handle = dummy_ensemble.resource_handle
      init_op = resources.initialize_resources(resources.shared_resources())
      init_op.run()

      # Nothing has been cached yet for either example.
      prev_tree_ids = [0, 0]
      prev_node_ids = [0, 0]

      # One list per feature id (0 and 1), one entry per example. The actual
      # values are irrelevant for a dummy ensemble.
      bucketized = [[67, 5], [9, 17]]

      # Build and run the cached-prediction op.
      fetched = sess.run(
          boosted_trees_ops.training_predict(
              handle,
              cached_tree_ids=prev_tree_ids,
              cached_node_ids=prev_node_ids,
              bucketized_features=bucketized,
              logits_dimension=1))
      logits_updates, new_tree_ids, new_node_ids = fetched

      # The cache and the logits must be untouched.
      self.assertAllClose(prev_tree_ids, new_tree_ids)
      self.assertAllClose(prev_node_ids, new_node_ids)
      self.assertAllClose([[0], [0]], logits_updates)
コード例 #2
0
  def testCachedPredictionFromPreviousTree(self):
    """Tests training_predict when the cache points into an earlier tree.

    The ensemble proto below holds three trees, each with weight 0.1; the first
    two are marked finalized and the last one is not. Both examples are cached
    in tree 0, and the test expects them to be advanced to tree 2 with the
    matching logits updates.
    """
    with self.cached_session() as sess:
      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7
            }
          }
          nodes {
            leaf {
              scalar: 5
            }
          }
          nodes {
            leaf {
              scalar: 6
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: false
        }
        tree_weights: 0.1
        tree_weights: 0.1
        tree_weights: 0.1
      """, ensemble_proto)

      # Build the ensemble resource from the three-tree proto above.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_proto.SerializeToString())
      handle = ensemble.resource_handle
      init_op = resources.initialize_resources(resources.shared_resources())
      init_op.run()

      # Both examples are cached in tree 0: the first at node 1, the second at
      # node 0.
      prev_tree_ids = [0, 0]
      prev_node_ids = [1, 0]

      # One list per feature id (0 and 1), one entry per example.
      bucketized = [[36, 32], [11, 27]]

      # Build and run the cached-prediction op.
      fetched = sess.run(
          boosted_trees_ops.training_predict(
              handle,
              cached_tree_ids=prev_tree_ids,
              cached_node_ids=prev_node_ids,
              bucketized_features=bucketized,
              logits_dimension=1))
      logits_updates, new_tree_ids, new_node_ids = fetched
      # Expected paths: example 1 reaches node 3 in tree 1 and node 2 in tree 2;
      # example 2 reaches node 2 in tree 1 and node 1 in tree 2.

      # Both examples end up in the last tree.
      self.assertAllClose([2, 2], new_tree_ids)
      # Within that last tree, the first example lands in node 2 and the second
      # in node 1.
      self.assertAllClose([2, 1], new_node_ids)
      # Example 1 was cached at node 1 of tree 0 (leaf 1.14), so only trees 1
      # and 2 contribute: change = 0.1 * (5.0 + 5.0).
      # Example 2 was cached at node 0 of tree 0, so all three trees
      # contribute: change = 0.1 * (1.14 + 7.0 - 7.0).
      self.assertAllClose([[1], [0.114]], logits_updates)
コード例 #3
0
  def testCachedPredictionIsCurrent(self):
    """Tests that an up-to-date cache is returned unchanged.

    The ensemble has a single tree with one root split and two leaves. Both
    examples are cached at those leaves (nodes 1 and 2), and the test expects
    the cached ids to come back as-is with zero logits updates.
    """
    with self.cached_session() as sess:
      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, ensemble_proto)

      # Build the ensemble resource: one tree with a single root split.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_proto.SerializeToString())
      handle = ensemble.resource_handle
      init_op = resources.initialize_resources(resources.shared_resources())
      init_op.run()

      # Both examples are cached in tree 0: the first at node 1, the second at
      # node 2.
      prev_tree_ids = [0, 0]
      prev_node_ids = [1, 2]

      # One list per feature id (0 and 1), one entry per example. The values
      # are not expected to matter because the tree has not changed.
      bucketized = [[67, 5], [9, 17]]

      # Build and run the cached-prediction op.
      fetched = sess.run(
          boosted_trees_ops.training_predict(
              handle,
              cached_tree_ids=prev_tree_ids,
              cached_node_ids=prev_node_ids,
              bucketized_features=bucketized,
              logits_dimension=1))
      logits_updates, new_tree_ids, new_node_ids = fetched

      # The cache and the logits must be untouched.
      self.assertAllClose(prev_tree_ids, new_tree_ids)
      self.assertAllClose(prev_node_ids, new_node_ids)
      self.assertAllClose([[0], [0]], logits_updates)
コード例 #4
0
  def testCachedPredictionFromTheSameTree(self):
    """Tests prediction that resumes from a cached node of the same tree.

    The single tree in the ensemble has three split nodes (0, 1, 2) and four
    leaves (3-6). The examples are cached at split nodes 1 and 0, and the test
    expects them to be moved to leaves 4 and 5 with logits updates relative to
    the `original_leaf` values of their cached nodes.
    """
    with self.cached_session() as sess:
      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 7.14
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
          nodes {
            leaf {
              scalar: -5.875
            }
          }
          nodes {
            leaf {
              scalar: -2.075
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, ensemble_proto)

      # Build the ensemble resource from the single-tree proto above.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_proto.SerializeToString())
      handle = ensemble.resource_handle
      init_op = resources.initialize_resources(resources.shared_resources())
      init_op.run()

      # Both examples are cached in tree 0: the first at node 1, the second at
      # node 0.
      prev_tree_ids = [0, 0]
      prev_node_ids = [1, 0]

      # One list per feature id (0 and 1), one entry per example.
      bucketized = [[67, 5], [9, 17]]

      # Build and run the cached-prediction op.
      fetched = sess.run(
          boosted_trees_ops.training_predict(
              handle,
              cached_tree_ids=prev_tree_ids,
              cached_node_ids=prev_node_ids,
              bucketized_features=bucketized,
              logits_dimension=1))
      logits_updates, new_tree_ids, new_node_ids = fetched

      # Both examples stay in tree 0.
      self.assertAllClose([0, 0], new_tree_ids)
      # Walking the full tree puts the first example in node 4 and the second
      # in node 5.
      self.assertAllClose([4, 5], new_node_ids)
      # The leaves reached hold 8.79 and -5.875. Relative to the original_leaf
      # values of the cached nodes (7.14 and -2) the differences are 1.65 and
      # -3.875, which are then scaled by the tree weight 0.1.
      self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logits_updates)
コード例 #5
0
def _bt_model_fn(
    features,
    labels,
    mode,
    head,
    feature_columns,
    tree_hparams,
    n_batches_per_layer,
    config,
    closed_form_grad_and_hess_fn=None,
    example_id_column_name=None,
    # TODO(youngheek): replace this later using other options.
    train_in_memory=False,
    name='boosted_trees'):
  """Gradient Boosted Trees model_fn.

  For any mode other than TRAIN, logits are predicted directly from the shared
  tree ensemble and the estimator spec is returned early. For TRAIN mode the
  function builds the training graph: it optionally sets up a cache of
  per-example training state, computes logits with `training_predict`, and
  hands `head` a `train_op_fn` that either centers the bias or grows the tree
  ensemble and then increments the global step.

  Args:
    features: dict of `Tensor`.
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
      dtype `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    head: A `head_lib._Head` instance.
    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
      They are sorted by `name` before use.
    tree_hparams: Hyper parameters object (the original documentation mentions
      a collections.namedtuple). This function reads `center_bias`, `n_trees`
      and `max_depth` from it and passes it on to `_get_max_splits` and to the
      ensemble growers.
    n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
      least n_batches_per_layer accumulations.
    config: `RunConfig` object to configure the runtime settings.
    closed_form_grad_and_hess_fn: a function that accepts logits and labels
      and returns gradients and hessians. By default, they are created by
      tf.gradients() from the loss.
    example_id_column_name: Name of the feature for a unique ID per example.
      Currently experimental -- not exposed to public API.
    train_in_memory: `bool`, when true, it assumes the dataset is in memory,
      i.e., input_fn should return the entire dataset as a single batch, and
      also n_batches_per_layer should be set as 1.
    name: Name to use for the model.

  Returns:
    An `EstimatorSpec` instance.

  Raises:
    ValueError: if `train_in_memory` is set while `config` describes a
      distributed setup (not chief, more than one worker replica, or any
      parameter-server replicas). The original documentation also lists
      invalid mode/params and a wrong `features` type; those are not raised
      directly in this function.
    AssertionError: if `train_in_memory` is set and `n_batches_per_layer` is
      not 1.
  """
  is_single_machine = (config.num_worker_replicas <= 1)
  # Sort by column name so that feature ordering is deterministic.
  sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
  center_bias = tree_hparams.center_bias

  if train_in_memory:
    # NOTE(review): `assert` statements are stripped when Python runs with -O,
    # so this check is skipped in optimized mode.
    assert n_batches_per_layer == 1, (
        'When train_in_memory is enabled, input_fn should return the entire '
        'dataset as a single batch, and n_batches_per_layer should be set as '
        '1.')
    if (not config.is_chief or config.num_worker_replicas > 1 or
        config.num_ps_replicas > 0):
      raise ValueError('train_in_memory is supported only for '
                       'non-distributed training.')
  # Device of a freshly created no-op; used below to place the local copy of
  # the ensemble in the distributed setting.
  worker_device = control_flow_ops.no_op().device
  # Ops collected here are grouped into the final train op by _train_op_fn.
  train_op = []
  # Note: `name` is rebound to the scope name returned by name_scope.
  with ops.name_scope(name) as name:
    # Prepare.
    global_step = training_util.get_or_create_global_step()
    bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
        sorted_feature_columns)
    # Create Ensemble resources.
    tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)

    # Create logits.
    if mode != model_fn.ModeKeys.TRAIN:
      input_feature_list = _get_transformed_features(features,
                                                     sorted_feature_columns)
      logits = boosted_trees_ops.predict(
          # For non-TRAIN mode, ensemble doesn't change after initialization,
          # so no local copy is needed; using tree_ensemble directly.
          tree_ensemble_handle=tree_ensemble.resource_handle,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
      # Early return: nothing below this point runs for EVAL/PREDICT.
      return head.create_estimator_spec(
          features=features,
          mode=mode,
          labels=labels,
          train_op_fn=control_flow_ops.no_op,
          logits=logits)

    # ============== Training graph ==============
    # Extract input features and set up cache for training.
    training_state_cache = None
    if train_in_memory:
      # cache transformed features as well for in-memory training.
      batch_size = array_ops.shape(labels)[0]
      input_feature_list, input_cache_op = (
          _cache_transformed_features(features, sorted_feature_columns,
                                      batch_size))
      train_op.append(input_cache_op)
      training_state_cache = _CacheTrainingStatesUsingVariables(
          batch_size, head.logits_dimension)
    else:
      input_feature_list = _get_transformed_features(features,
                                                     sorted_feature_columns)
      # Without an example id column no training-state cache is created.
      if example_id_column_name:
        example_ids = features[example_id_column_name]
        training_state_cache = _CacheTrainingStatesUsingHashTable(
            example_ids, head.logits_dimension)

    # Variable that determines whether bias centering is needed.
    center_bias_var = variable_scope.variable(
        initial_value=center_bias, name='center_bias_needed', trainable=False)
    if is_single_machine:
      # Single machine: read from the shared ensemble directly, so there is
      # nothing to reload.
      local_tree_ensemble = tree_ensemble
      ensemble_reload = control_flow_ops.no_op()
    else:
      # Have a local copy of ensemble for the distributed setting.
      with ops.device(worker_device):
        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
            name=name + '_local', is_local=True)
      # TODO(soroush): Do partial updates if this becomes a bottleneck.
      # Refresh the local copy from the serialized shared ensemble.
      ensemble_reload = local_tree_ensemble.deserialize(
          *tree_ensemble.serialize())

    if training_state_cache:
      cached_tree_ids, cached_node_ids, cached_logits = (
          training_state_cache.lookup())
    else:
      # Always start from the beginning when no cache is set up.
      batch_size = array_ops.shape(labels)[0]
      cached_tree_ids, cached_node_ids, cached_logits = (
          array_ops.zeros([batch_size], dtype=dtypes.int32),
          _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
          array_ops.zeros(
              [batch_size, head.logits_dimension], dtype=dtypes.float32))

    # Read ensemble state and predict only after the local copy is reloaded.
    with ops.control_dependencies([ensemble_reload]):
      (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
       last_layer_nodes_range) = local_tree_ensemble.get_states()
      summary.scalar('ensemble/num_trees', num_trees)
      summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
      summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

      partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
          tree_ensemble_handle=local_tree_ensemble.resource_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
      # training_predict returns only the update relative to the cached state.
      logits = cached_logits + partial_logits

    # Create training graph.
    def _train_op_fn(loss):
      """Run one training iteration.

      Appends the cache update, the model update (bias centering or tree
      growing) and the global step increment to `train_op`, then groups them.

      Args:
        loss: Loss tensor; differentiated w.r.t. logits when no closed-form
          gradient/hessian function is given.

      Returns:
        The grouped op named 'train_op'.
      """
      if training_state_cache:
        # Cache logits only after center_bias is complete, if it's in progress.
        train_op.append(
            control_flow_ops.cond(
                center_bias_var, control_flow_ops.no_op,
                lambda: training_state_cache.insert(tree_ids, node_ids, logits))
        )

      if closed_form_grad_and_hess_fn:
        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
      else:
        # Hessians are obtained by differentiating the gradients once more
        # with respect to the logits.
        gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
        hessians = gradients_impl.gradients(
            gradients, logits, name='Hessians')[0]

      # TODO(youngheek): perhaps storage could be optimized by storing stats
      # with the dimension max_splits_per_layer, instead of max_splits (for the
      # entire tree).
      max_splits = _get_max_splits(tree_hparams)

      # One list of per-feature stats summaries for each group of features
      # that share the same number of buckets.
      stats_summaries_list = []
      for i, feature_ids in enumerate(feature_ids_list):
        num_buckets = bucket_size_list[i]
        summaries = [
            array_ops.squeeze(
                boosted_trees_ops.make_stats_summary(
                    node_ids=node_ids,
                    gradients=gradients,
                    hessians=hessians,
                    bucketized_features_list=[input_feature_list[f]],
                    max_splits=max_splits,
                    num_buckets=num_buckets),
                axis=0) for f in feature_ids
        ]
        stats_summaries_list.append(summaries)

      # In-memory single-machine training uses the in-memory grower; every
      # other configuration goes through the accumulator-based grower.
      if train_in_memory and is_single_machine:
        grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
      else:
        grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
                                            stamp_token, n_batches_per_layer,
                                            bucket_size_list, config.is_chief)

      # While center_bias_var is true the step centers the bias; otherwise it
      # grows the tree.
      update_model = control_flow_ops.cond(
          center_bias_var,
          functools.partial(
              grower.center_bias,
              center_bias_var,
              gradients,
              hessians,
          ),
          functools.partial(grower.grow_tree, stats_summaries_list,
                            feature_ids_list, last_layer_nodes_range))
      train_op.append(update_model)

      # Increment the global step only after the model update has run.
      with ops.control_dependencies([update_model]):
        increment_global = distribute_lib.increment_var(global_step)
        train_op.append(increment_global)

      return control_flow_ops.group(train_op, name='train_op')

  # Only reached in TRAIN mode; other modes returned above.
  estimator_spec = head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)
  # Add an early stop hook.
  estimator_spec = estimator_spec._replace(
      training_hooks=estimator_spec.training_hooks +
      (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
                           tree_hparams.n_trees, tree_hparams.max_depth),))
  return estimator_spec
コード例 #6
0
def _bt_model_fn(
    features,
    labels,
    mode,
    head,
    feature_columns,
    tree_hparams,
    n_batches_per_layer,
    config,
    closed_form_grad_and_hess_fn=None,
    example_id_column_name=None,
    # TODO(youngheek): replace this later using other options.
    train_in_memory=False,
    name='boosted_trees'):
  """Gradient Boosted Trees model_fn.

  For any mode other than TRAIN, logits are predicted directly from the shared
  tree ensemble. For TRAIN mode the function optionally sets up a cache of
  per-example training state, computes logits with `training_predict`, and
  hands `head` a `train_op_fn` that builds stats summaries and grows the
  ensemble, either directly (in-memory, single machine) or through
  conditional accumulators.

  Args:
    features: dict of `Tensor`.
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
      dtype `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    head: A `head_lib._Head` instance.
    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
      They are sorted by `name` before use.
    tree_hparams: Hyper parameters object (the original documentation mentions
      a collections.namedtuple). This function reads `max_depth`, `n_trees`,
      `l1`, `l2`, `tree_complexity`, `min_node_weight` and `learning_rate`
      from it.
    n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
      least n_batches_per_layer accumulations.
    config: `RunConfig` object to configure the runtime settings.
    closed_form_grad_and_hess_fn: a function that accepts logits and labels
      and returns gradients and hessians. By default, they are created by
      tf.gradients() from the loss.
    example_id_column_name: Name of the feature for a unique ID per example.
      Currently experimental -- not exposed to public API.
    train_in_memory: `bool`, when true, it assumes the dataset is in memory,
      i.e., input_fn should return the entire dataset as a single batch, and
      also n_batches_per_layer should be set as 1.
    name: Name to use for the model.

  Returns:
    An `EstimatorSpec` instance.

  Raises:
    ValueError: if `train_in_memory` is set while `config` describes a
      distributed setup (not chief, more than one worker replica, or any
      parameter-server replicas). The original documentation also lists
      invalid mode/params and a wrong `features` type; those are not raised
      directly in this function.
    AssertionError: if `train_in_memory` is set and `n_batches_per_layer` is
      not 1.
  """
  is_single_machine = (config.num_worker_replicas <= 1)

  # Sort by column name so that feature ordering is deterministic.
  sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
  if train_in_memory:
    # NOTE(review): `assert` statements are stripped when Python runs with -O,
    # so this check is skipped in optimized mode.
    assert n_batches_per_layer == 1, (
        'When train_in_memory is enabled, input_fn should return the entire '
        'dataset as a single batch, and n_batches_per_layer should be set as '
        '1.')
    if (not config.is_chief or config.num_worker_replicas > 1 or
        config.num_ps_replicas > 0):
      raise ValueError('train_in_memory is supported only for '
                       'non-distributed training.')
  # Device of a freshly created no-op; used below to place the local copy of
  # the ensemble in the distributed setting.
  worker_device = control_flow_ops.no_op().device
  # max_splits is computed as 2^D - 1, where D is tree_hparams.max_depth. It
  # is passed to make_stats_summary and calculate_best_gains_per_feature and
  # also sizes the stats accumulators below.
  # TODO(youngheek): perhaps storage could be optimized by storing stats with
  # the dimension max_splits_per_layer, instead of max_splits (for the entire
  # tree).
  max_splits = (1 << tree_hparams.max_depth) - 1
  # Ops collected here are grouped into the final train op by _train_op_fn.
  train_op = []
  # Note: `name` is rebound to the scope name returned by name_scope.
  with ops.name_scope(name) as name:
    # Prepare.
    global_step = training_util.get_or_create_global_step()
    bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
        sorted_feature_columns)
    # Extract input features and set up cache for training.
    # The training-state cache is only ever created in TRAIN mode.
    training_state_cache = None
    if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
      # cache transformed features as well for in-memory training.
      batch_size = array_ops.shape(labels)[0]
      input_feature_list, input_cache_op = (
          _cache_transformed_features(features, sorted_feature_columns,
                                      batch_size))
      train_op.append(input_cache_op)
      training_state_cache = _CacheTrainingStatesUsingVariables(
          batch_size, head.logits_dimension)
    else:
      input_feature_list = _get_transformed_features(features,
                                                     sorted_feature_columns)
      if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
        example_ids = features[example_id_column_name]
        training_state_cache = _CacheTrainingStatesUsingHashTable(
            example_ids, head.logits_dimension)

    # Create Ensemble resources.
    tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
    # Create logits.
    if mode != model_fn.ModeKeys.TRAIN:
      logits = boosted_trees_ops.predict(
          # For non-TRAIN mode, ensemble doesn't change after initialization,
          # so no local copy is needed; using tree_ensemble directly.
          tree_ensemble_handle=tree_ensemble.resource_handle,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
    else:
      if is_single_machine:
        # Single machine: read from the shared ensemble directly, so there is
        # nothing to reload.
        local_tree_ensemble = tree_ensemble
        ensemble_reload = control_flow_ops.no_op()
      else:
        # Have a local copy of ensemble for the distributed setting.
        with ops.device(worker_device):
          local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
              name=name + '_local', is_local=True)
        # TODO(soroush): Do partial updates if this becomes a bottleneck.
        # Refresh the local copy from the serialized shared ensemble.
        ensemble_reload = local_tree_ensemble.deserialize(
            *tree_ensemble.serialize())
      if training_state_cache:
        cached_tree_ids, cached_node_ids, cached_logits = (
            training_state_cache.lookup())
      else:
        # Always start from the beginning when no cache is set up.
        batch_size = array_ops.shape(labels)[0]
        cached_tree_ids, cached_node_ids, cached_logits = (
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            array_ops.zeros(
                [batch_size, head.logits_dimension], dtype=dtypes.float32))
      # Read ensemble state and predict only after the local copy is reloaded.
      with ops.control_dependencies([ensemble_reload]):
        (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
         last_layer_nodes_range) = local_tree_ensemble.get_states()
        summary.scalar('ensemble/num_trees', num_trees)
        summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
        summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

        partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
            tree_ensemble_handle=local_tree_ensemble.resource_handle,
            cached_tree_ids=cached_tree_ids,
            cached_node_ids=cached_node_ids,
            bucketized_features=input_feature_list,
            logits_dimension=head.logits_dimension)
      # training_predict returns only the update relative to the cached state.
      logits = cached_logits + partial_logits

  # Create training graph.
    # NOTE: _train_op_fn uses names (tree_ids, node_ids, stamp_token,
    # last_layer_nodes_range) that are only bound in the TRAIN branch above.
    def _train_op_fn(loss):
      """Run one training iteration.

      Appends the cache update, the global step increment and the tree
      growing ops to `train_op`, then groups them.

      Args:
        loss: Loss tensor; differentiated w.r.t. logits when no closed-form
          gradient/hessian function is given.

      Returns:
        The grouped op named 'train_op'.
      """
      if training_state_cache:
        train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
      if closed_form_grad_and_hess_fn:
        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
      else:
        # Hessians are obtained by differentiating the gradients once more
        # with respect to the logits.
        gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
        hessians = gradients_impl.gradients(
            gradients, logits, name='Hessians')[0]

      # One list of per-feature stats summaries for each group of features
      # that share the same number of buckets.
      stats_summaries_list = []
      for i, feature_ids in enumerate(feature_ids_list):
        num_buckets = bucket_size_list[i]
        summaries = [
            array_ops.squeeze(
                boosted_trees_ops.make_stats_summary(
                    node_ids=node_ids,
                    gradients=gradients,
                    hessians=hessians,
                    bucketized_features_list=[input_feature_list[f]],
                    max_splits=max_splits,
                    num_buckets=num_buckets),
                axis=0) for f in feature_ids
        ]
        stats_summaries_list.append(summaries)

      # Filled in the accumulator branch below; read by the closure
      # grow_tree_from_accumulated_summaries_fn.
      accumulators = []

      def grow_tree_from_stats_summaries(stats_summaries_list,
                                         feature_ids_list):
        """Updates ensemble based on the best gains from stats summaries.

        Args:
          stats_summaries_list: One list of stats summaries per feature group.
          feature_ids_list: One list of feature ids per feature group; must
            have the same length as `stats_summaries_list`.

        Returns:
          The op returned by `boosted_trees_ops.update_ensemble`.
        """
        node_ids_per_feature = []
        gains_list = []
        thresholds_list = []
        left_node_contribs_list = []
        right_node_contribs_list = []
        all_feature_ids = []

        assert len(stats_summaries_list) == len(feature_ids_list)

        # Compute best gains per feature group and concatenate the results
        # across groups, keeping feature ids aligned with the other lists.
        for i, feature_ids in enumerate(feature_ids_list):
          (numeric_node_ids_per_feature, numeric_gains_list,
           numeric_thresholds_list, numeric_left_node_contribs_list,
           numeric_right_node_contribs_list) = (
               boosted_trees_ops.calculate_best_gains_per_feature(
                   node_id_range=last_layer_nodes_range,
                   stats_summary_list=stats_summaries_list[i],
                   l1=tree_hparams.l1,
                   l2=tree_hparams.l2,
                   tree_complexity=tree_hparams.tree_complexity,
                   min_node_weight=tree_hparams.min_node_weight,
                   max_splits=max_splits))

          all_feature_ids += feature_ids
          node_ids_per_feature += numeric_node_ids_per_feature
          gains_list += numeric_gains_list
          thresholds_list += numeric_thresholds_list
          left_node_contribs_list += numeric_left_node_contribs_list
          right_node_contribs_list += numeric_right_node_contribs_list

        grow_op = boosted_trees_ops.update_ensemble(
            # TODO: confirm whether local_tree_ensemble or tree_ensemble should
            # be used; currently the shared tree_ensemble is updated.
            tree_ensemble.resource_handle,
            feature_ids=all_feature_ids,
            node_ids=node_ids_per_feature,
            gains=gains_list,
            thresholds=thresholds_list,
            left_node_contribs=left_node_contribs_list,
            right_node_contribs=right_node_contribs_list,
            learning_rate=tree_hparams.learning_rate,
            max_depth=tree_hparams.max_depth,
            pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
        return grow_op

      if train_in_memory and is_single_machine:
        # In-memory, single machine: grow directly from this batch's
        # summaries, without accumulators.
        train_op.append(distribute_lib.increment_var(global_step))
        train_op.append(
            grow_tree_from_stats_summaries(stats_summaries_list,
                                           feature_ids_list))
      else:
        # apply_grad ops that must run before the step is counted.
        dependencies = []

        # One accumulator per feature group; each batch's stacked summaries
        # are applied to it together with the current stamp token.
        for i, feature_ids in enumerate(feature_ids_list):
          stats_summaries = stats_summaries_list[i]
          accumulator = data_flow_ops.ConditionalAccumulator(
              dtype=dtypes.float32,
              # The stats consist of grads and hessians (the last dimension).
              shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
              shared_name='numeric_stats_summary_accumulator_' + str(i))
          accumulators.append(accumulator)

          apply_grad = accumulator.apply_grad(
              array_ops.stack(stats_summaries, axis=0), stamp_token)
          dependencies.append(apply_grad)

        def grow_tree_from_accumulated_summaries_fn():
          """Updates the tree with the best layer from accumulated summaries."""
          # Take out the accumulated summaries from the accumulator and grow.
          # NOTE(review): this empty list is immediately overwritten by the
          # assignment below.
          stats_summaries_list = []

          stats_summaries_list = [
              array_ops.unstack(accumulator.take_grad(1), axis=0)
              for accumulator in accumulators
          ]

          grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
                                                   feature_ids_list)
          return grow_op

        with ops.control_dependencies(dependencies):
          train_op.append(distribute_lib.increment_var(global_step))
          # Only the chief checks the accumulators and grows the tree.
          if config.is_chief:
            min_accumulated = math_ops.reduce_min(
                array_ops.stack(
                    [acc.num_accumulated() for acc in accumulators]))

            # Grow once every accumulator holds at least n_batches_per_layer
            # contributions; otherwise do nothing this step.
            train_op.append(
                control_flow_ops.cond(
                    math_ops.greater_equal(min_accumulated,
                                           n_batches_per_layer),
                    grow_tree_from_accumulated_summaries_fn,
                    control_flow_ops.no_op,
                    name='wait_until_n_batches_accumulated'))

      return control_flow_ops.group(train_op, name='train_op')

  estimator_spec = head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)
  if mode == model_fn.ModeKeys.TRAIN:
    # Add an early stop hook.
    estimator_spec = estimator_spec._replace(
        training_hooks=estimator_spec.training_hooks +
        (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
                             tree_hparams.n_trees, tree_hparams.max_depth),))
  return estimator_spec
コード例 #7
0
    def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
        """Tests cached predictions that point at post-pruned nodes.

        Tree 0 is finalized and carries post-pruning metadata; tree 1 is a
        single leaf that is not finalized yet. All examples are cached in
        tree 0.
        """
        with self.cached_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id:0
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 5
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
              scalar: 0.0783
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.55
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 4
        }
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Build the ensemble resource from the proto and initialize it.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # All six examples were cached in tree 0, at old node ids 3..8.
            # Old ids 3, 4, 7 and 8 were removed by post-pruning (they map to
            # node 1 now), while old ids 5 and 6 became nodes 3 and 4.
            num_examples = 6
            cached_tree_ids = [0] * num_examples
            cached_node_ids = list(range(3, 9))

            # One list of bucketized values per feature (feature 0, feature 1).
            bucketized_features = [
                [12, 17, 35, 36, 23, 11],
                [12, 12, 17, 18, 123, 24],
            ]

            # Run the training-time prediction against the cached positions.
            logit_deltas, tree_ids, node_ids = session.run(
                boosted_trees_ops.training_predict(
                    ensemble_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))

            # Every example advances to the last tree, which is only a root.
            self.assertAllClose([1] * num_examples, tree_ids)
            self.assertAllClose([0] * num_examples, node_ids)

            # Cached value plus the matching logit_change yields the value of
            # the remapped tree-0 node (e.g. 0.08 - 0.07 = 0.01); the root of
            # tree 1 is added on top of that.
            cached_values = [[0.08], [0.093], [0.0553], [0.0783],
                             [0.15 + 0.08], [0.5 + 0.08]]
            root = 0.55
            tree_0_values = (0.01, 0.01, 0.0553, 0.0783, 0.01, 0.01)
            expected_logits = [[root + value] for value in tree_0_values]
            self.assertAllClose(expected_logits, logit_deltas + cached_values)
コード例 #8
0
  def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
    """Tests cached predictions that point at post-pruned nodes.

    Tree 0 is finalized and carries post-pruning metadata; tree 1 is a single
    leaf that is not finalized yet. All examples are cached in tree 0.
    """
    with self.cached_session() as session:
      ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id:0
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 5
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
              scalar: 0.0783
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.55
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 4
        }
      """
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(ensemble_text, tree_ensemble_config)

      # Build the ensemble resource from the proto and initialize it.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      ensemble_handle = ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # All six examples were cached in tree 0, at old node ids 3..8. Old ids
      # 3, 4, 7 and 8 were removed by post-pruning (they map to node 1 now),
      # while old ids 5 and 6 became nodes 3 and 4.
      num_examples = 6
      cached_tree_ids = [0] * num_examples
      cached_node_ids = list(range(3, 9))

      # One list of bucketized values per feature (feature 0, feature 1).
      bucketized_features = [
          [12, 17, 35, 36, 23, 11],
          [12, 12, 17, 18, 123, 24],
      ]

      # Run the training-time prediction against the cached positions.
      logit_deltas, tree_ids, node_ids = session.run(
          boosted_trees_ops.training_predict(
              ensemble_handle,
              cached_tree_ids=cached_tree_ids,
              cached_node_ids=cached_node_ids,
              bucketized_features=bucketized_features,
              logits_dimension=1))

      # Every example advances to the last tree, which is only a root.
      self.assertAllClose([1] * num_examples, tree_ids)
      self.assertAllClose([0] * num_examples, node_ids)

      # Cached value plus the matching logit_change yields the value of the
      # remapped tree-0 node (e.g. 0.08 - 0.07 = 0.01); the root of tree 1 is
      # added on top of that.
      cached_values = [[0.08], [0.093], [0.0553], [0.0783], [0.15 + 0.08],
                       [0.5 + 0.08]]
      root = 0.55
      tree_0_values = (0.01, 0.01, 0.0553, 0.0783, 0.01, 0.01)
      expected_logits = [[root + value] for value in tree_0_values]
      self.assertAllClose(expected_logits, logit_deltas + cached_values)
コード例 #9
0
    def testCategoricalSplits(self):
        """Tests training prediction on a tree built from categorical splits."""
        with self.cached_session() as session:
            ensemble_text = """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
        }
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Build the ensemble resource (a single tree with two categorical
            # splits) from the proto and initialize it.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Nothing has been cached: all three examples start at the root
            # of tree 0.
            cached_tree_ids = [0, 0, 0]
            cached_node_ids = [0, 0, 0]

            # One list of values per feature (feature 0, feature 1).
            bucketized_features = [[13, 1, 3], [2, 2, 1]]

            # Run the training-time prediction from the root.
            logit_deltas, tree_ids, node_ids = session.run(
                boosted_trees_ops.training_predict(
                    ensemble_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))

            # The expectations correspond to taking the left child when the
            # feature equals the split value and the right child otherwise.
            self.assertAllClose([0, 0, 0], tree_ids)
            self.assertAllClose([3, 4, 2], node_ids)
            self.assertAllClose([[5.], [6.], [7.]], logit_deltas)
コード例 #10
0
    def testNoCachedPredictionButTreeExists(self):
        """Tests that uncached examples pick up the existing tree."""
        with self.cached_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Build the ensemble resource (one tree with a single root split)
            # from the proto and initialize it.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Neither of the two examples has been cached yet, so both start
            # at the root of tree 0.
            cached_tree_ids = [0, 0]
            cached_node_ids = [0, 0]

            # Only feature 0 exists; the split threshold on it is 15.
            bucketized_features = [[67, 5]]

            # Run the training-time prediction from the root.
            logit_deltas, tree_ids, node_ids = session.run(
                boosted_trees_ops.training_predict(
                    ensemble_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))

            # Both examples stay in the first (and only) tree: 67 lands in the
            # right leaf, 5 in the left one, each scaled by the 0.1 weight.
            self.assertAllClose([0, 0], tree_ids)
            self.assertAllClose([2, 1], node_ids)
            self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logit_deltas)
コード例 #11
0
    def testCachedPredictionFromPreviousTree(self):
        """Tests predictions when the cache points into an earlier tree."""
        with self.cached_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7
            }
          }
          nodes {
            leaf {
              scalar: 5
            }
          }
          nodes {
            leaf {
              scalar: 6
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: false
        }
        tree_weights: 0.1
        tree_weights: 0.1
        tree_weights: 0.1
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Build the three-tree ensemble resource from the proto and
            # initialize it.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Both examples were cached in tree 0: the first at node 1, the
            # second at the root (node 0).
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # One list of bucketized values per feature (feature 0, feature 1).
            bucketized_features = [[36, 32], [11, 27]]

            # Run the training-time prediction against the cached positions.
            logit_deltas, tree_ids, node_ids = session.run(
                boosted_trees_ops.training_predict(
                    ensemble_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))

            # Both examples end up in the last tree (tree 2).
            self.assertAllClose([2, 2], tree_ids)
            # In tree 2 the first example (feature 0 = 36) is in node 2 and
            # the second (feature 0 = 32) is in node 1.
            self.assertAllClose([2, 1], node_ids)
            # First example: the expected update holds only tree 1 (5.0) and
            #   tree 2 (5.0), i.e. 0.1 * (5.0 + 5.0); it was cached at a leaf
            #   of tree 0 already.
            # Second example: cached at the root, so the expected update is
            #   tree 0 (1.14) + tree 1 (7.0) + tree 2 (-7.0), times 0.1.
            self.assertAllClose([[1], [0.114]], logit_deltas)
コード例 #12
0
    def testCachedPredictionFromTheSameTree(self):
        """Tests updating predictions cached at split nodes of the same tree."""
        with self.cached_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 7.14
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
          nodes {
            leaf {
              scalar: -5.875
            }
          }
          nodes {
            leaf {
              scalar: -2.075
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Build the single-tree ensemble resource from the proto and
            # initialize it.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # The first example was cached at node 1, the second at the root;
            # both nodes are splits that keep their former leaf value in
            # original_leaf.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # One list of bucketized values per feature (feature 0, feature 1).
            bucketized_features = [[67, 5], [9, 17]]

            # Run the training-time prediction against the cached positions.
            logit_deltas, tree_ids, node_ids = session.run(
                boosted_trees_ops.training_predict(
                    ensemble_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))

            # Both examples remain in tree 0.
            self.assertAllClose([0, 0], tree_ids)
            # Walking the full tree puts the first example in leaf 4 and the
            # second one in leaf 5.
            self.assertAllClose([4, 5], node_ids)
            # The leaves hold 8.79 and -5.875 while the cached nodes held 7.14
            # and -2, so the differences are 1.65 and -3.875; each is scaled
            # by the tree weight 0.1.
            self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logit_deltas)
コード例 #13
0
    def testCachedPredictionIsCurrent(self):
        """Tests that an up-to-date cache yields no prediction change."""
        with self.cached_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Build the ensemble resource (one tree with a single root split)
            # from the proto and initialize it.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # The two examples are cached at the two leaves (nodes 1 and 2)
            # of the only tree.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 2]

            # One list of bucketized values per feature (feature 0,
            # feature 1). The expectations below do not depend on them.
            bucketized_features = [[67, 5], [9, 17]]

            # Run the training-time prediction against the cached positions.
            logit_deltas, tree_ids, node_ids = session.run(
                boosted_trees_ops.training_predict(
                    ensemble_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))

            # The cache was already current: same positions, zero updates.
            self.assertAllClose(cached_tree_ids, tree_ids)
            self.assertAllClose(cached_node_ids, node_ids)
            self.assertAllClose([[0], [0]], logit_deltas)
コード例 #14
0
  def testCategoricalSplits(self):
    """Tests training prediction on a tree built from categorical splits."""
    with self.cached_session() as session:
      ensemble_text = """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
        }
      """
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(ensemble_text, tree_ensemble_config)

      # Build the ensemble resource (a single tree with two categorical
      # splits) from the proto and initialize it.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      ensemble_handle = ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Nothing has been cached: all three examples start at the root of
      # tree 0.
      cached_tree_ids = [0, 0, 0]
      cached_node_ids = [0, 0, 0]

      # One list of values per feature (feature 0, feature 1).
      bucketized_features = [[13, 1, 3], [2, 2, 1]]

      # Run the training-time prediction from the root.
      logit_deltas, tree_ids, node_ids = session.run(
          boosted_trees_ops.training_predict(
              ensemble_handle,
              cached_tree_ids=cached_tree_ids,
              cached_node_ids=cached_node_ids,
              bucketized_features=bucketized_features,
              logits_dimension=1))

      # The expectations correspond to taking the left child when the feature
      # equals the split value and the right child otherwise.
      self.assertAllClose([0, 0, 0], tree_ids)
      self.assertAllClose([3, 4, 2], node_ids)
      self.assertAllClose([[5.], [6.], [7.]], logit_deltas)
コード例 #15
0
    def testCachedPredictionTheWholeTreeWasPruned(self):
        """Tests cached predictions when pruning left only the root."""
        with self.cached_session() as session:
            ensemble_text = """
        trees {
          nodes {
            leaf {
              scalar: 0.00
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: -6.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 5.0
          }
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Build the ensemble resource from the proto and initialize it.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # The examples were cached at nodes 1 and 2 of tree 0; the
            # post-pruning metadata maps both of them back to the root.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 2]

            # One list of bucketized values per feature (feature 0,
            # feature 1). The remaining tree has no split that reads them.
            bucketized_features = [[12, 17], [12, 12]]

            # Run the training-time prediction against the cached positions.
            logit_deltas, tree_ids, node_ids = session.run(
                boosted_trees_ops.training_predict(
                    ensemble_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))

            # Both examples sit at the root of the only tree.
            self.assertAllClose([0, 0], tree_ids)
            self.assertAllClose([0, 0], node_ids)

            # The updates equal the logit_change recorded for old nodes 1
            # and 2.
            self.assertAllClose([[-6.0], [5.0]], logit_deltas)
コード例 #16
0
  def testNoCachedPredictionButTreeExists(self):
    """Tests that uncached examples pick up the existing tree."""
    with self.cached_session() as session:
      ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(ensemble_text, tree_ensemble_config)

      # Build the ensemble resource (one tree with a single root split) from
      # the proto and initialize it.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      ensemble_handle = ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Neither of the two examples has been cached yet, so both start at the
      # root of tree 0.
      cached_tree_ids = [0, 0]
      cached_node_ids = [0, 0]

      # Only feature 0 exists; the split threshold on it is 15.
      bucketized_features = [[67, 5]]

      # Run the training-time prediction from the root.
      logit_deltas, tree_ids, node_ids = session.run(
          boosted_trees_ops.training_predict(
              ensemble_handle,
              cached_tree_ids=cached_tree_ids,
              cached_node_ids=cached_node_ids,
              bucketized_features=bucketized_features,
              logits_dimension=1))

      # Both examples stay in the first (and only) tree: 67 lands in the right
      # leaf, 5 in the left one, each scaled by the 0.1 weight.
      self.assertAllClose([0, 0], tree_ids)
      self.assertAllClose([2, 1], node_ids)
      self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logit_deltas)
コード例 #17
0
def _bt_model_fn(
        features,
        labels,
        mode,
        head,
        feature_columns,
        tree_hparams,
        n_batches_per_layer,
        config,
        closed_form_grad_and_hess_fn=None,
        example_id_column_name=None,
        # TODO(youngheek): replace this later using other options.
        train_in_memory=False,
        name='boosted_trees'):
    """Gradient Boosted Trees model_fn.

    Builds the logits for the requested mode and hands `head` a `train_op_fn`
    that computes gradient/hessian statistics, optionally accumulates them
    across batches, and grows the tree ensemble.

    Args:
      features: dict of `Tensor`.
      labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
        dtype `int32` or `int64` in the range `[0, n_classes)`.
      mode: Defines whether this is training, evaluation or prediction.
        See `ModeKeys`.
      head: A `head_lib._Head` instance.
      feature_columns: Iterable of `feature_column._FeatureColumn` model
        inputs. They are sorted by `name` before use.
      tree_hparams: TODO. collections.namedtuple for hyper parameters. This
        function reads `max_depth`, `l1`, `l2`, `tree_complexity`,
        `min_node_weight`, `learning_rate` and `n_trees` from it.
      n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
        least n_batches_per_layer accumulations. Must compare equal to 1 when
        `train_in_memory` is set (checked with an assert).
      config: `RunConfig` object to configure the runtime settings.
      closed_form_grad_and_hess_fn: a function that accepts logits and labels
        and returns gradients and hessians. By default, they are created by
        tf.gradients() from the loss.
      example_id_column_name: Name of the feature for a unique ID per example.
        Currently experimental -- not exposed to public API.
      train_in_memory: `bool`, when true, it assumes the dataset is in memory,
        i.e., input_fn should return the entire dataset as a single batch, and
        also n_batches_per_layer should be set as 1.
      name: Name to use for the model.

    Returns:
      An `EstimatorSpec` instance. In TRAIN mode an early-stop hook is
      appended to its `training_hooks`.

    Raises:
      ValueError: if `train_in_memory` is set while the config is not chief,
        has more than one worker replica, or has any ps replicas. The
        original documentation also lists invalid mode/params and wrongly
        typed features; no such check is visible in this function
        (presumably raised by callees -- verify).
      AssertionError: if `train_in_memory` is set and `n_batches_per_layer`
        is not 1.
    """
    # Single-machine means the config reports at most one worker replica.
    is_single_machine = (config.num_worker_replicas <= 1)

    # Order the columns by name so processing order does not depend on the
    # order in which the caller supplied them.
    sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
    if train_in_memory:
        assert n_batches_per_layer == 1, (
            'When train_in_memory is enabled, input_fn should return the entire '
            'dataset as a single batch, and n_batches_per_layer should be set as '
            '1.')
        # In-memory training is rejected for any distributed configuration.
        if (not config.is_chief or config.num_worker_replicas > 1
                or config.num_ps_replicas > 0):
            raise ValueError('train_in_memory is supported only for '
                             'non-distributed training.')
    # Device string of an op created in the current device scope; used below
    # to place the local ensemble copy in the distributed setting.
    worker_device = control_flow_ops.no_op().device
    # max_splits evaluates to 2^D - 1, where D = tree_hparams.max_depth. It is
    # passed to the stats-summary and best-gain ops and is one dimension of
    # the accumulator shape further below.
    # TODO(youngheek): perhaps storage could be optimized by storing stats with
    # the dimension max_splits_per_layer, instead of max_splits (for the entire
    # tree).
    max_splits = (1 << tree_hparams.max_depth) - 1
    # Ops collected here are grouped into the final train op by _train_op_fn.
    train_op = []
    with ops.name_scope(name) as name:
        # Prepare.
        global_step = training_util.get_or_create_global_step()
        bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
            sorted_feature_columns)
        # Extract input features and set up cache for training.
        training_state_cache = None
        if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
            # cache transformed features as well for in-memory training.
            batch_size = array_ops.shape(labels)[0]
            input_feature_list, input_cache_op = (_cache_transformed_features(
                features, sorted_feature_columns, batch_size))
            train_op.append(input_cache_op)
            training_state_cache = _CacheTrainingStatesUsingVariables(
                batch_size, head.logits_dimension)
        else:
            input_feature_list = _get_transformed_features(
                features, sorted_feature_columns)
            # Outside in-memory mode, training state is cached only when an
            # example-id feature name was provided.
            if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
                example_ids = features[example_id_column_name]
                training_state_cache = _CacheTrainingStatesUsingHashTable(
                    example_ids, head.logits_dimension)

        # Create Ensemble resources.
        tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
        # Create logits.
        if mode != model_fn.ModeKeys.TRAIN:
            logits = boosted_trees_ops.predict(
                # For non-TRAIN mode, ensemble doesn't change after initialization,
                # so no local copy is needed; using tree_ensemble directly.
                tree_ensemble_handle=tree_ensemble.resource_handle,
                bucketized_features=input_feature_list,
                logits_dimension=head.logits_dimension)
        else:
            if is_single_machine:
                # No separate copy: read from the same ensemble that is
                # updated, and reloading is a no-op.
                local_tree_ensemble = tree_ensemble
                ensemble_reload = control_flow_ops.no_op()
            else:
                # Have a local copy of ensemble for the distributed setting.
                with ops.device(worker_device):
                    local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
                        name=name + '_local', is_local=True)
                # TODO(soroush): Do partial updates if this becomes a bottleneck.
                ensemble_reload = local_tree_ensemble.deserialize(
                    *tree_ensemble.serialize())
            if training_state_cache:
                cached_tree_ids, cached_node_ids, cached_logits = (
                    training_state_cache.lookup())
            else:
                # Always start from the beginning when no cache is set up.
                batch_size = array_ops.shape(labels)[0]
                cached_tree_ids, cached_node_ids, cached_logits = (
                    array_ops.zeros([batch_size], dtype=dtypes.int32),
                    array_ops.zeros([batch_size], dtype=dtypes.int32),
                    array_ops.zeros([batch_size, head.logits_dimension],
                                    dtype=dtypes.float32))
            # Read ensemble state and predict only after the local copy has
            # been reloaded.
            with ops.control_dependencies([ensemble_reload]):
                (stamp_token, num_trees, num_finalized_trees,
                 num_attempted_layers,
                 last_layer_nodes_range) = local_tree_ensemble.get_states()
                summary.scalar('ensemble/num_trees', num_trees)
                summary.scalar('ensemble/num_finalized_trees',
                               num_finalized_trees)
                summary.scalar('ensemble/num_attempted_layers',
                               num_attempted_layers)

                partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
                    tree_ensemble_handle=local_tree_ensemble.resource_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=input_feature_list,
                    logits_dimension=head.logits_dimension)
            # Final logits are the cached logits plus the update computed by
            # training_predict.
            logits = cached_logits + partial_logits

        # Create training graph.
        # NOTE(review): this closure uses tree_ids, node_ids, stamp_token and
        # last_layer_nodes_range, which are bound only in the TRAIN branch
        # above; presumably head invokes it only in TRAIN mode -- confirm.
        def _train_op_fn(loss):
            """Run one training iteration."""
            if training_state_cache:
                # Store the new tree/node positions and logits in the cache.
                train_op.append(
                    training_state_cache.insert(tree_ids, node_ids, logits))
            if closed_form_grad_and_hess_fn:
                gradients, hessians = closed_form_grad_and_hess_fn(
                    logits, labels)
            else:
                # Differentiate the loss w.r.t. logits, then differentiate
                # the resulting gradients w.r.t. logits again.
                gradients = gradients_impl.gradients(loss,
                                                     logits,
                                                     name='Gradients')[0]
                hessians = gradients_impl.gradients(gradients,
                                                    logits,
                                                    name='Hessians')[0]

            # One list of per-feature summaries for each bucket-size group.
            stats_summaries_list = []
            for i, feature_ids in enumerate(feature_ids_list):
                num_buckets = bucket_size_list[i]
                summaries = [
                    array_ops.squeeze(boosted_trees_ops.make_stats_summary(
                        node_ids=node_ids,
                        gradients=gradients,
                        hessians=hessians,
                        bucketized_features_list=[input_feature_list[f]],
                        max_splits=max_splits,
                        num_buckets=num_buckets),
                                      axis=0) for f in feature_ids
                ]
                stats_summaries_list.append(summaries)

            # Filled in the accumulating branch below and read by
            # grow_tree_from_accumulated_summaries_fn.
            accumulators = []

            def grow_tree_from_stats_summaries(stats_summaries_list,
                                               feature_ids_list):
                """Updates ensemble based on the best gains from stats summaries."""
                node_ids_per_feature = []
                gains_list = []
                thresholds_list = []
                left_node_contribs_list = []
                right_node_contribs_list = []
                all_feature_ids = []

                assert len(stats_summaries_list) == len(feature_ids_list)

                # Compute best gains per group and concatenate the results
                # across groups, keeping feature ids aligned with them.
                for i, feature_ids in enumerate(feature_ids_list):
                    (numeric_node_ids_per_feature, numeric_gains_list,
                     numeric_thresholds_list, numeric_left_node_contribs_list,
                     numeric_right_node_contribs_list) = (
                         boosted_trees_ops.calculate_best_gains_per_feature(
                             node_id_range=last_layer_nodes_range,
                             stats_summary_list=stats_summaries_list[i],
                             l1=tree_hparams.l1,
                             l2=tree_hparams.l2,
                             tree_complexity=tree_hparams.tree_complexity,
                             min_node_weight=tree_hparams.min_node_weight,
                             max_splits=max_splits))

                    all_feature_ids += feature_ids
                    node_ids_per_feature += numeric_node_ids_per_feature
                    gains_list += numeric_gains_list
                    thresholds_list += numeric_thresholds_list
                    left_node_contribs_list += numeric_left_node_contribs_list
                    right_node_contribs_list += numeric_right_node_contribs_list

                grow_op = boosted_trees_ops.update_ensemble(
                    # Confirm if local_tree_ensemble or tree_ensemble should be used.
                    tree_ensemble.resource_handle,
                    feature_ids=all_feature_ids,
                    node_ids=node_ids_per_feature,
                    gains=gains_list,
                    thresholds=thresholds_list,
                    left_node_contribs=left_node_contribs_list,
                    right_node_contribs=right_node_contribs_list,
                    learning_rate=tree_hparams.learning_rate,
                    max_depth=tree_hparams.max_depth,
                    pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
                return grow_op

            if train_in_memory and is_single_machine:
                # Whole dataset is one batch: bump the step and grow directly
                # from this batch's summaries, with no accumulation.
                train_op.append(distribute_lib.increment_var(global_step))
                train_op.append(
                    grow_tree_from_stats_summaries(stats_summaries_list,
                                                   feature_ids_list))
            else:
                # apply_grad ops that must run before the step increment and
                # the conditional growth below.
                dependencies = []

                # One accumulator per bucket-size group.
                for i, feature_ids in enumerate(feature_ids_list):
                    stats_summaries = stats_summaries_list[i]
                    accumulator = data_flow_ops.ConditionalAccumulator(
                        dtype=dtypes.float32,
                        # The stats consist of grads and hessians (the last dimension).
                        shape=[
                            len(feature_ids), max_splits, bucket_size_list[i],
                            2
                        ],
                        shared_name='numeric_stats_summary_accumulator_' +
                        str(i))
                    accumulators.append(accumulator)

                    # Submit this batch's stacked summaries together with the
                    # stamp_token obtained from get_states().
                    apply_grad = accumulator.apply_grad(
                        array_ops.stack(stats_summaries, axis=0), stamp_token)
                    dependencies.append(apply_grad)

                def grow_tree_from_accumulated_summaries_fn():
                    """Updates the tree with the best layer from accumulated summaries."""
                    # Take out the accumulated summaries from the accumulator and grow.
                    # NOTE(review): this empty list is overwritten by the
                    # assignment right below and is never read.
                    stats_summaries_list = []

                    stats_summaries_list = [
                        array_ops.unstack(accumulator.take_grad(1), axis=0)
                        for accumulator in accumulators
                    ]

                    grow_op = grow_tree_from_stats_summaries(
                        stats_summaries_list, feature_ids_list)
                    return grow_op

                with ops.control_dependencies(dependencies):
                    train_op.append(distribute_lib.increment_var(global_step))
                    # Only the chief adds the conditional tree-growing op.
                    if config.is_chief:
                        # Smallest accumulation count across all accumulators.
                        min_accumulated = math_ops.reduce_min(
                            array_ops.stack([
                                acc.num_accumulated() for acc in accumulators
                            ]))

                        # Grow once every accumulator has at least
                        # n_batches_per_layer accumulations; otherwise no-op.
                        train_op.append(
                            control_flow_ops.cond(
                                math_ops.greater_equal(min_accumulated,
                                                       n_batches_per_layer),
                                grow_tree_from_accumulated_summaries_fn,
                                control_flow_ops.no_op,
                                name='wait_until_n_batches_accumulated'))

            return control_flow_ops.group(train_op, name='train_op')

    estimator_spec = head.create_estimator_spec(features=features,
                                                mode=mode,
                                                labels=labels,
                                                train_op_fn=_train_op_fn,
                                                logits=logits)
    if mode == model_fn.ModeKeys.TRAIN:
        # Add an early stop hook.
        # The hook receives the finalized-tree and attempted-layer counters
        # read from the ensemble plus the n_trees / max_depth hyperparameters.
        estimator_spec = estimator_spec._replace(
            training_hooks=estimator_spec.training_hooks +
            (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
                                 tree_hparams.n_trees, tree_hparams.max_depth),
             ))
    return estimator_spec
コード例 #18
0
  def testCachedPredictionTheWholeTreeWasPruned(self):
    """Checks logit updates when cached nodes were pruned back to the root.

    The only tree is a single root leaf; its metadata records post-pruning
    entries mapping old nodes 1 and 2 to node 0 with logit changes -6.0 and
    5.0. Examples cached at nodes 1 and 2 must end at node 0 with exactly
    those logit updates.
    """
    with self.cached_session() as session:
      ensemble_text = """
        trees {
          nodes {
            leaf {
              scalar: 0.00
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: -6.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 5.0
          }
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(ensemble_text, ensemble_proto)

      # Materialize the ensemble resource from the proto and initialize it.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_proto.SerializeToString())
      handle = ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Both examples were cached in tree 0, at nodes 1 and 2 respectively;
      # those nodes have since been pruned to the root.
      prior_tree_ids = [0, 0]
      prior_node_ids = [1, 2]

      # Two bucketized features (0 and 1); their values are not expected to
      # influence the result.
      bucketized = [[12, 17], [12, 12]]

      fetched = session.run(
          boosted_trees_ops.training_predict(
              handle,
              cached_tree_ids=prior_tree_ids,
              cached_node_ids=prior_node_ids,
              bucketized_features=bucketized,
              logits_dimension=1))
      logits_updates, new_tree_ids, new_node_ids = fetched

      # Still in the last (only) tree, now at its root.
      root_positions = [0, 0]
      self.assertAllClose(root_positions, new_tree_ids)
      self.assertAllClose(root_positions, new_node_ids)

      self.assertAllClose([[-6.0], [5.0]], logits_updates)
コード例 #19
0
ファイル: boosted_trees.py プロジェクト: zanes2016/tensorflow
def _bt_model_fn(
        features,
        labels,
        mode,
        head,
        feature_columns,
        tree_hparams,
        n_batches_per_layer,
        config,
        closed_form_grad_and_hess_fn=None,
        example_id_column_name=None,
        # TODO(youngheek): replace this later using other options.
        train_in_memory=False,
        name='TreeEnsembleModel'):
    """Gradient Boosted Decision Tree model_fn.

    Builds the logits for the requested mode and hands `head` a `train_op_fn`
    that computes gradient/hessian statistics, optionally accumulates them
    across batches, and grows the tree ensemble.

    Args:
      features: dict of `Tensor`.
      labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
        dtype `int32` or `int64` in the range `[0, n_classes)`.
      mode: Defines whether this is training, evaluation or prediction.
        See `ModeKeys`.
      head: A `head_lib._Head` instance.
      feature_columns: Iterable of `feature_column._FeatureColumn` model
        inputs.
      tree_hparams: TODO. collections.namedtuple for hyper parameters. This
        function reads `max_depth`, `l1`, `l2`, `tree_complexity`,
        `learning_rate` and `n_trees` from it.
      n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
        least n_batches_per_layer accumulations. Must compare equal to 1 when
        `train_in_memory` is set (checked with an assert).
      config: `RunConfig` object to configure the runtime settings.
      closed_form_grad_and_hess_fn: a function that accepts logits and labels
        and returns gradients and hessians. By default, they are created by
        tf.gradients() from the loss.
      example_id_column_name: Name of the feature for a unique ID per example.
        Currently experimental -- not exposed to public API.
      train_in_memory: `bool`, when true, it assumes the dataset is in memory,
        i.e., input_fn should return the entire dataset as a single batch, and
        also n_batches_per_layer should be set as 1.
      name: Name to use for the model.

    Returns:
      An `EstimatorSpec` instance. In TRAIN mode an early-stop hook is
      appended to its `training_hooks`.

    Raises:
      ValueError: mode or params are invalid, or features has the wrong type.
        No explicit raise is visible in this function (presumably raised by
        callees -- verify).
      AssertionError: if `train_in_memory` is set and `n_batches_per_layer`
        is not 1.
    """
    # Single-machine means the config reports exactly one worker replica.
    is_single_machine = (config.num_worker_replicas == 1)
    if train_in_memory:
        assert n_batches_per_layer == 1, (
            'When train_in_memory is enabled, input_fn should return the entire '
            'dataset as a single batch, and n_batches_per_layer should be set as '
            '1.')
    # Device string of an op created in the current device scope; used below
    # to place the local ensemble copy in the distributed setting.
    worker_device = control_flow_ops.no_op().device
    # max_splits evaluates to 2^D - 1, where D = tree_hparams.max_depth. It is
    # passed to the stats-summary and best-gain ops and is one dimension of
    # the accumulator shape further below.
    # TODO(youngheek): perhaps storage could be optimized by storing stats with
    # the dimension max_splits_per_layer, instead of max_splits (for the entire
    # tree).
    max_splits = (1 << tree_hparams.max_depth) - 1
    with ops.name_scope(name) as name:
        # Prepare.
        global_step = training_util.get_or_create_global_step()
        input_feature_list, num_buckets = _get_transformed_features(
            features, feature_columns)
        if train_in_memory and mode == model_fn.ModeKeys.TRAIN:
            # NOTE(review): _keep_as_local_variable is defined elsewhere;
            # presumably it stores each feature in a local variable -- confirm.
            input_feature_list = [
                _keep_as_local_variable(feature)
                for feature in input_feature_list
            ]
        num_features = len(input_feature_list)

        # Training-state cache; stays None outside TRAIN mode or when neither
        # caching strategy below applies.
        cache = None
        if mode == model_fn.ModeKeys.TRAIN:
            if train_in_memory and is_single_machine:  # maybe just train_in_memory?
                batch_size = array_ops.shape(input_feature_list[0])[0]
                cache = _CacheTrainingStatesUsingVariables(
                    batch_size, head.logits_dimension)
            elif example_id_column_name:
                example_ids = features[example_id_column_name]
                cache = _CacheTrainingStatesUsingHashTable(
                    example_ids, head.logits_dimension)

        # Create Ensemble resources.
        if is_single_machine:
            # No separate copy: the local ensemble is the ensemble itself and
            # reloading is a no-op.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
            local_tree_ensemble = tree_ensemble
            ensemble_reload = control_flow_ops.no_op()
        else:
            # Keep a local copy placed on worker_device, refreshed by
            # deserializing the serialized main ensemble.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
            with ops.device(worker_device):
                local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
                    name=name + '_local', is_local=True)
            # TODO(soroush): Do partial updates if this becomes a bottleneck.
            ensemble_reload = local_tree_ensemble.deserialize(
                *tree_ensemble.serialize())

        # Create logits.
        if mode != model_fn.ModeKeys.TRAIN:
            logits = boosted_trees_ops.predict(
                tree_ensemble_handle=local_tree_ensemble.resource_handle,
                bucketized_features=input_feature_list,
                logits_dimension=head.logits_dimension,
                max_depth=tree_hparams.max_depth)
        else:
            if cache:
                cached_tree_ids, cached_node_ids, cached_logits = cache.lookup(
                )
            else:
                # Always start from the beginning when no cache is set up.
                batch_size = array_ops.shape(input_feature_list[0])[0]
                cached_tree_ids, cached_node_ids, cached_logits = (
                    array_ops.zeros([batch_size], dtype=dtypes.int32),
                    array_ops.zeros([batch_size], dtype=dtypes.int32),
                    array_ops.zeros([batch_size, head.logits_dimension],
                                    dtype=dtypes.float32))
            # Read ensemble state and predict only after the local copy has
            # been reloaded.
            with ops.control_dependencies([ensemble_reload]):
                (stamp_token, num_trees, num_finalized_trees,
                 num_attempted_layers) = local_tree_ensemble.get_states()
                summary.scalar('ensemble/num_trees', num_trees)
                summary.scalar('ensemble/num_finalized_trees',
                               num_finalized_trees)
                summary.scalar('ensemble/num_attempted_layers',
                               num_attempted_layers)

                partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
                    tree_ensemble_handle=local_tree_ensemble.resource_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=input_feature_list,
                    logits_dimension=head.logits_dimension,
                    max_depth=tree_hparams.max_depth)
            # Final logits are the cached logits plus the update computed by
            # training_predict.
            logits = cached_logits + partial_logits

        # Create training graph.
        # NOTE(review): this closure uses tree_ids, node_ids and stamp_token,
        # which are bound only in the TRAIN branch above; presumably head
        # invokes it only in TRAIN mode -- confirm.
        def _train_op_fn(loss):
            """Run one training iteration."""
            # Ops collected here are grouped into the returned train op.
            train_op = []
            if cache:
                # Store the new tree/node positions and logits in the cache.
                train_op.append(cache.insert(tree_ids, node_ids, logits))
            if closed_form_grad_and_hess_fn:
                gradients, hessians = closed_form_grad_and_hess_fn(
                    logits, labels)
            else:
                # Differentiate the loss w.r.t. logits, then differentiate
                # the resulting gradients w.r.t. logits again.
                gradients = gradients_impl.gradients(loss,
                                                     logits,
                                                     name='Gradients')[0]
                hessians = gradients_impl.gradients(gradients,
                                                    logits,
                                                    name='Hessians')[0]
            # One stats summary per feature; all features share num_buckets.
            stats_summary_list = [
                array_ops.squeeze(boosted_trees_ops.make_stats_summary(
                    node_ids=node_ids,
                    gradients=gradients,
                    hessians=hessians,
                    bucketized_features_list=[input_feature_list[f]],
                    max_splits=max_splits,
                    num_buckets=num_buckets),
                                  axis=0) for f in range(num_features)
            ]

            def grow_tree_from_stats_summaries(stats_summary_list):
                """Updates ensemble based on the best gains from stats summaries."""
                # The node id range is [min, max] of the node_ids that
                # training_predict produced for the current batch.
                (node_ids_per_feature, gains_list, thresholds_list,
                 left_node_contribs_list, right_node_contribs_list) = (
                     boosted_trees_ops.calculate_best_gains_per_feature(
                         node_id_range=array_ops.stack([
                             math_ops.reduce_min(node_ids),
                             math_ops.reduce_max(node_ids)
                         ]),
                         stats_summary_list=stats_summary_list,
                         l1=tree_hparams.l1,
                         l2=tree_hparams.l2,
                         tree_complexity=tree_hparams.tree_complexity,
                         max_splits=max_splits))
                grow_op = boosted_trees_ops.update_ensemble(
                    # Confirm if local_tree_ensemble or tree_ensemble should be used.
                    tree_ensemble.resource_handle,
                    feature_ids=math_ops.range(0,
                                               num_features,
                                               dtype=dtypes.int32),
                    node_ids=node_ids_per_feature,
                    gains=gains_list,
                    thresholds=thresholds_list,
                    left_node_contribs=left_node_contribs_list,
                    right_node_contribs=right_node_contribs_list,
                    learning_rate=tree_hparams.learning_rate,
                    max_depth=tree_hparams.max_depth,
                    pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
                return grow_op

            if train_in_memory and is_single_machine:
                # Whole dataset is one batch: bump the step and grow directly
                # from this batch's summaries, with no accumulation.
                train_op.append(state_ops.assign_add(global_step, 1))
                train_op.append(
                    grow_tree_from_stats_summaries(stats_summary_list))
            else:
                summary_accumulator = data_flow_ops.ConditionalAccumulator(
                    dtype=dtypes.float32,
                    # The stats consist of gradients and hessians (the last dimension).
                    shape=[num_features, max_splits, num_buckets, 2],
                    shared_name='stats_summary_accumulator')
                # Submit this batch's stacked summaries together with the
                # stamp_token obtained from get_states().
                apply_grad = summary_accumulator.apply_grad(
                    array_ops.stack(stats_summary_list, axis=0), stamp_token)

                def grow_tree_from_accumulated_summaries_fn():
                    """Updates the tree with the best layer from accumulated summaries."""
                    # Take out the accumulated summaries from the accumulator and grow.
                    stats_summary_list = array_ops.unstack(
                        summary_accumulator.take_grad(1), axis=0)
                    grow_op = grow_tree_from_stats_summaries(
                        stats_summary_list)
                    return grow_op

                with ops.control_dependencies([apply_grad]):
                    train_op.append(state_ops.assign_add(global_step, 1))
                    # Only the chief adds the conditional tree-growing op.
                    if config.is_chief:
                        # Grow once the accumulator has at least
                        # n_batches_per_layer accumulations; otherwise no-op.
                        train_op.append(
                            control_flow_ops.cond(
                                math_ops.greater_equal(
                                    summary_accumulator.num_accumulated(),
                                    n_batches_per_layer),
                                grow_tree_from_accumulated_summaries_fn,
                                control_flow_ops.no_op,
                                name='wait_until_n_batches_accumulated'))

            return control_flow_ops.group(train_op, name='train_op')

    estimator_spec = head.create_estimator_spec(features=features,
                                                mode=mode,
                                                labels=labels,
                                                train_op_fn=_train_op_fn,
                                                logits=logits)
    if mode == model_fn.ModeKeys.TRAIN:
        # Add an early stop hook.
        # The hook receives the ensemble's tree counter and the n_trees
        # hyperparameter.
        estimator_spec = estimator_spec._replace(
            training_hooks=estimator_spec.training_hooks +
            (StopAtNumTreesHook(num_trees, tree_hparams.n_trees), ))
    return estimator_spec