예제 #1
0
    def testCachedPredictionOnEmptyEnsemble(self):
        """Checks cached prediction leaves state untouched on an empty ensemble."""
        # Cache state fed to the op: both examples start at tree 0, node 0.
        initial_tree_ids = [0, 0]
        initial_node_ids = [0, 0]
        # Bucketized values for features 0 and 1 (two examples each).
        bucketized_features = [[67, 5], [9, 17]]

        with self.test_session() as sess:
            # Build the ensemble from an empty serialized proto.
            ensemble = boosted_trees_ops.TreeEnsemble('ensemble',
                                                      serialized_proto='')
            handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Run the training-time prediction op against the cached state.
            outputs = sess.run(
                boosted_trees_ops.training_predict(
                    handle,
                    cached_tree_ids=initial_tree_ids,
                    cached_node_ids=initial_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))
            logits_updates, new_tree_ids, new_node_ids = outputs

            # The cache must come back unchanged, with zero logit updates.
            self.assertAllClose(initial_tree_ids, new_tree_ids)
            self.assertAllClose(initial_node_ids, new_node_ids)
            self.assertAllClose([[0], [0]], logits_updates)
예제 #2
0
    def testContribsForOnlyABiasNode(self):
        """Checks debug outputs for an ensemble that holds only a bias node.

        Such an ensemble can arise, for example, when the final ensemble
        contains a single tree that was pruned all the way up to its root.
        """
        ensemble_text = """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
      """
        # Three features with two examples each; the single-leaf tree has no
        # split that could read any of them.
        bucketized_features = [[36, 32], [13, -29], [11, 27]]

        # The only contribution is the root leaf value scaled by its tree
        # weight, so each logits path has one entry and no feature ids.
        bias = 1.72 * 0.1
        expected_feature_ids = ((), ())
        expected_logits_paths = ((bias, ), (bias, ))

        with self.test_session() as sess:
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, ensemble_proto)

            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_proto.SerializeToString())
            handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            debug_op = boosted_trees_ops.example_debug_outputs(
                handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            # Each returned element is a serialized DebugOutput proto.
            parsed_outputs = []
            for serialized in sess.run(debug_op):
                debug_output = boosted_trees_pb2.DebugOutput()
                debug_output.ParseFromString(serialized)
                parsed_outputs.append(debug_output)

            feature_ids = [out.feature_ids for out in parsed_outputs]
            logits_paths = [out.logits_path for out in parsed_outputs]

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)
예제 #3
0
 def testCreate(self):
   """Checks that a freshly created ensemble reports all-zero state."""
   with self.test_session():
     tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
     resources.initialize_resources(resources.shared_resources()).run()

     self.assertEqual(0, tree_ensemble.get_stamp_token().eval())

     # get_states() returns a 4-tuple here; the first element is not checked.
     states = tree_ensemble.get_states()
     _, num_trees, num_finalized_trees, num_attempted_layers = states
     for counter in (num_trees, num_finalized_trees, num_attempted_layers):
       self.assertEqual(0, counter.eval())
예제 #4
0
 def testCreate(self):
     """Checks the initial state reported by a freshly created ensemble."""
     with self.cached_session():
         tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
         resources.initialize_resources(resources.shared_resources()).run()

         self.assertEqual(0, self.evaluate(tree_ensemble.get_stamp_token()))

         # get_states() returns a 5-tuple here; the first element is not
         # checked.
         states = tree_ensemble.get_states()
         (_, num_trees, num_finalized_trees, num_attempted_layers,
          nodes_range) = states
         for counter in (num_trees, num_finalized_trees,
                         num_attempted_layers):
             self.assertEqual(0, self.evaluate(counter))
         self.assertAllEqual([0, 1], self.evaluate(nodes_range))
예제 #5
0
    def testPredictionOnEmptyEnsemble(self):
        """Checks that predicting with an empty ensemble yields zero logits."""
        # Bucketized values for features 0 and 1 (two examples each).
        bucketized_features = [[36, 32], [11, 27]]
        zero_logits = [[0.0], [0.0]]

        with self.cached_session() as sess:
            # Build the ensemble from an empty serialized proto.
            ensemble = boosted_trees_ops.TreeEnsemble('ensemble',
                                                      serialized_proto='')
            handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # The predict op must run without error and return zeros.
            logits = sess.run(
                boosted_trees_ops.predict(
                    handle,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))
            self.assertAllClose(zero_logits, logits)
예제 #6
0
    def testPredictionMultipleTreeMultiClass(self):
        """Checks two-dimensional logits predicted from a three-tree ensemble."""
        ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              vector: {
                value: 0.51
              }
              vector: {
                value: 1.14
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: 1.29
              }
              vector: {
                value: 8.79
              }
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              vector: {
                value: -4.33
              }
              vector: {
                value: 7.0
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: 0.2
              }
              vector: {
                value: 5.0
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: -4.1
              }
              vector: {
                value: 6.0
              }
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              vector: {
                value: 2.0
              }
              vector: {
                value: -7.0
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: 6.3
              }
              vector: {
                value: 5.0
              }
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
      """
        # Bucketized values for features 0 and 1 (two examples each).
        bucketized_features = [[36, 32], [11, 27]]

        # Leaves reached, weighted by tree_weights (0.1, 0.2, 1.0):
        #   Example 1: tree 0 -> (0.51, 1.14), tree 1 -> (0.2, 5.0),
        #              tree 2 -> (6.3, 5.0)
        #     logits = (0.1*0.51 + 0.2*0.2 + 1*6.3,
        #               0.1*1.14 + 0.2*5.0 + 1*5.0)
        #   Example 2: tree 0 -> (0.51, 1.14), tree 1 -> (-4.33, 7.0),
        #              tree 2 -> (2.0, -7.0)
        #     logits = (0.1*0.51 + 0.2*-4.33 + 1*2.0,
        #               0.1*1.14 + 0.2*7.0 + 1*-7.0)
        num_logits = 2
        expected_logits = [[6.391, 6.114], [1.185, -5.486]]

        with self.cached_session() as sess:
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, ensemble_proto)

            # Load the pre-built three-tree ensemble into the resource.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_proto.SerializeToString())
            handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            logits = sess.run(
                boosted_trees_ops.predict(
                    handle,
                    bucketized_features=bucketized_features,
                    logits_dimension=num_logits))
            self.assertAllClose(expected_logits, logits)
예제 #7
0
  def testMetadataWhenCantSplitDueToEmptySplits(self):
    """Checks growing metadata still advances when no split can be made."""
    initial_ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 0.714
            }
          }
          nodes {
            leaf {
              scalar: -0.4375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
    # Same tree and per-tree metadata as above; only the global
    # num_layers_attempted counter is expected to move (1 -> 2).
    expected_ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 0.714
            }
          }
          nodes {
            leaf {
              scalar: -0.4375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
    with self.test_session() as sess:
      initial_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(initial_ensemble_text, initial_proto)

      # Start from an existing ensemble that already has one root split.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=initial_proto.SerializeToString())
      handle = ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Attempt to grow while offering no split candidates at all.
      sess.run(
          boosted_trees_ops.update_ensemble(
              handle,
              learning_rate=0.1,
              pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
              max_depth=2,
              feature_ids=[],
              node_ids=[],
              gains=[],
              thresholds=[],
              left_node_contribs=[],
              right_node_contribs=[]))

      # Read the ensemble back and compare it with the expectation.
      new_stamp, serialized = sess.run(ensemble.serialize())
      result_proto = boosted_trees_pb2.TreeEnsemble()
      result_proto.ParseFromString(serialized)

      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_ensemble_text, result_proto)
예제 #8
0
    def testCachedPredictionTheWholeTreeWasPruned(self):
        """Checks cached prediction when cached nodes were pruned to the root."""
        ensemble_text = """
        trees {
          nodes {
            leaf {
              scalar: 0.00
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: -6.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 5.0
          }
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
        # Both examples were cached in tree 0, at nodes 1 and 2 respectively;
        # the post-pruning metadata maps both of those nodes to node 0.
        prior_tree_ids = [0, 0]
        prior_node_ids = [1, 2]
        # Bucketized values for features 0 and 1; the tree is a lone leaf, so
        # no split reads them.
        bucketized_features = [[12, 17], [12, 12]]

        with self.test_session() as sess:
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, ensemble_proto)

            # Load the already-pruned ensemble into the resource.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_proto.SerializeToString())
            handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Run the training-time prediction op against the cached state.
            outputs = sess.run(
                boosted_trees_ops.training_predict(
                    handle,
                    cached_tree_ids=prior_tree_ids,
                    cached_node_ids=prior_node_ids,
                    bucketized_features=bucketized_features,
                    logits_dimension=1))
            logits_updates, new_tree_ids, new_node_ids = outputs

            # Examples stay in tree 0 and land on its root node.
            self.assertAllClose([0, 0], new_tree_ids)
            self.assertAllClose([0, 0], new_node_ids)

            # Logit updates equal the logit_change recorded for nodes 1 and 2.
            self.assertAllClose([[-6.0], [5.0]], logits_updates)
예제 #9
0
def _bt_model_fn(
        features,
        labels,
        mode,
        head,
        feature_columns,
        tree_hparams,
        n_batches_per_layer,
        config,
        closed_form_grad_and_hess_fn=None,
        example_id_column_name=None,
        # TODO(youngheek): replace this later using other options.
        train_in_memory=False,
        name='TreeEnsembleModel'):
    """Gradient Boosted Decision Tree model_fn.

    Builds the logits for a boosted-trees ensemble and, through the nested
    `_train_op_fn`, the ops that accumulate statistics and grow the ensemble.
    Both are handed to `head.create_estimator_spec` to produce the result.

    Args:
      features: dict of `Tensor`.
      labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
        dtype `int32` or `int64` in the range `[0, n_classes)`.
      mode: Defines whether this is training, evaluation or prediction.
        See `ModeKeys`.
      head: A `head_lib._Head` instance. Its `logits_dimension` and
        `create_estimator_spec` are used here.
      feature_columns: Iterable of `feature_column._FeatureColumn` model
        inputs.
      tree_hparams: Hyper-parameter container (originally documented as a
        `collections.namedtuple`). The attributes read here are `max_depth`,
        `l1`, `l2`, `tree_complexity`, `learning_rate` and `n_trees`.
      n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after
        at least n_batches_per_layer accumulations.
      config: `RunConfig` object to configure the runtime settings. Its
        `num_worker_replicas` and `is_chief` are read here.
      closed_form_grad_and_hess_fn: a function that accepts logits and labels
        and returns gradients and hessians. By default, they are created by
        tf.gradients() from the loss.
      example_id_column_name: Name of the feature for a unique ID per example.
        Currently experimental -- not exposed to public API.
      train_in_memory: `bool`, when true, it assumes the dataset is in memory,
        i.e., input_fn should return the entire dataset as a single batch,
        and also n_batches_per_layer should be set as 1.
      name: Name to use for the model.

    Returns:
      An `EstimatorSpec` instance.

    Raises:
      AssertionError: `train_in_memory` is set but `n_batches_per_layer` is
        not 1.
      ValueError: mode or params are invalid, or features has the wrong type
        (per the original documentation; not raised directly in this body).
    """
    is_single_machine = (config.num_worker_replicas == 1)
    if train_in_memory:
        assert n_batches_per_layer == 1, (
            'When train_in_memory is enabled, input_fn should return the entire '
            'dataset as a single batch, and n_batches_per_layer should be set as '
            '1.')
    # Device of a freshly created op; used below to place the local ensemble
    # copy in the multi-worker case.
    worker_device = control_flow_ops.no_op().device
    # max_splits as computed here is 2^max_depth - 1.
    # TODO(youngheek): perhaps storage could be optimized by storing stats with
    # the dimension max_splits_per_layer, instead of max_splits (for the entire
    # tree).
    max_splits = (1 << tree_hparams.max_depth) - 1
    with ops.name_scope(name) as name:
        # Prepare.
        global_step = training_util.get_or_create_global_step()
        input_feature_list, num_buckets = _get_transformed_features(
            features, feature_columns)
        if train_in_memory and mode == model_fn.ModeKeys.TRAIN:
            input_feature_list = [
                _keep_as_local_variable(feature)
                for feature in input_feature_list
            ]
        num_features = len(input_feature_list)

        # Optional cache of per-example training state (tree ids, node ids,
        # logits). It stays None outside TRAIN mode, and in TRAIN mode when
        # neither caching condition below holds.
        cache = None
        if mode == model_fn.ModeKeys.TRAIN:
            if train_in_memory and is_single_machine:  # maybe just train_in_memory?
                batch_size = array_ops.shape(input_feature_list[0])[0]
                cache = _CacheTrainingStatesUsingVariables(
                    batch_size, head.logits_dimension)
            elif example_id_column_name:
                example_ids = features[example_id_column_name]
                cache = _CacheTrainingStatesUsingHashTable(
                    example_ids, head.logits_dimension)

        # Create Ensemble resources.
        if is_single_machine:
            # A single ensemble serves as both the shared and the local one,
            # so there is nothing to reload.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
            local_tree_ensemble = tree_ensemble
            ensemble_reload = control_flow_ops.no_op()
        else:
            tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
            with ops.device(worker_device):
                local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
                    name=name + '_local', is_local=True)
            # The local copy is refreshed by deserializing the serialized
            # shared ensemble.
            # TODO(soroush): Do partial updates if this becomes a bottleneck.
            ensemble_reload = local_tree_ensemble.deserialize(
                *tree_ensemble.serialize())

        # Create logits.
        if mode != model_fn.ModeKeys.TRAIN:
            logits = boosted_trees_ops.predict(
                tree_ensemble_handle=local_tree_ensemble.resource_handle,
                bucketized_features=input_feature_list,
                logits_dimension=head.logits_dimension,
                max_depth=tree_hparams.max_depth)
        else:
            if cache:
                cached_tree_ids, cached_node_ids, cached_logits = cache.lookup(
                )
            else:
                # Always start from the beginning when no cache is set up.
                batch_size = array_ops.shape(input_feature_list[0])[0]
                cached_tree_ids, cached_node_ids, cached_logits = (
                    array_ops.zeros([batch_size], dtype=dtypes.int32),
                    array_ops.zeros([batch_size], dtype=dtypes.int32),
                    array_ops.zeros([batch_size, head.logits_dimension],
                                    dtype=dtypes.float32))
            # Read the ensemble state and predict only after the local
            # ensemble has been reloaded.
            with ops.control_dependencies([ensemble_reload]):
                (stamp_token, num_trees, num_finalized_trees,
                 num_attempted_layers) = local_tree_ensemble.get_states()
                summary.scalar('ensemble/num_trees', num_trees)
                summary.scalar('ensemble/num_finalized_trees',
                               num_finalized_trees)
                summary.scalar('ensemble/num_attempted_layers',
                               num_attempted_layers)

                partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
                    tree_ensemble_handle=local_tree_ensemble.resource_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=input_feature_list,
                    logits_dimension=head.logits_dimension,
                    max_depth=tree_hparams.max_depth)
            # Training logits are the cached logits plus the partial logits
            # returned by training_predict.
            logits = cached_logits + partial_logits

        # Create training graph.
        def _train_op_fn(loss):
            """Run one training iteration."""
            # tree_ids, node_ids and stamp_token used below are only assigned
            # in the TRAIN branch above.
            train_op = []
            if cache:
                train_op.append(cache.insert(tree_ids, node_ids, logits))
            if closed_form_grad_and_hess_fn:
                gradients, hessians = closed_form_grad_and_hess_fn(
                    logits, labels)
            else:
                # Hessians are obtained by differentiating the gradients with
                # respect to the logits a second time.
                gradients = gradients_impl.gradients(loss,
                                                     logits,
                                                     name='Gradients')[0]
                hessians = gradients_impl.gradients(gradients,
                                                    logits,
                                                    name='Hessians')[0]
            # One stats summary per feature: each call receives a one-element
            # feature list and the leading axis of its result is squeezed.
            stats_summary_list = [
                array_ops.squeeze(boosted_trees_ops.make_stats_summary(
                    node_ids=node_ids,
                    gradients=gradients,
                    hessians=hessians,
                    bucketized_features_list=[input_feature_list[f]],
                    max_splits=max_splits,
                    num_buckets=num_buckets),
                                  axis=0) for f in range(num_features)
            ]

            def grow_tree_from_stats_summaries(stats_summary_list):
                """Updates ensemble based on the best gains from stats summaries."""
                # The node id range passed in is [min(node_ids), max(node_ids)]
                # over the current batch.
                (node_ids_per_feature, gains_list, thresholds_list,
                 left_node_contribs_list, right_node_contribs_list) = (
                     boosted_trees_ops.calculate_best_gains_per_feature(
                         node_id_range=array_ops.stack([
                             math_ops.reduce_min(node_ids),
                             math_ops.reduce_max(node_ids)
                         ]),
                         stats_summary_list=stats_summary_list,
                         l1=tree_hparams.l1,
                         l2=tree_hparams.l2,
                         tree_complexity=tree_hparams.tree_complexity,
                         max_splits=max_splits))
                grow_op = boosted_trees_ops.update_ensemble(
                    # Confirm if local_tree_ensemble or tree_ensemble should be used.
                    tree_ensemble.resource_handle,
                    feature_ids=math_ops.range(0,
                                               num_features,
                                               dtype=dtypes.int32),
                    node_ids=node_ids_per_feature,
                    gains=gains_list,
                    thresholds=thresholds_list,
                    left_node_contribs=left_node_contribs_list,
                    right_node_contribs=right_node_contribs_list,
                    learning_rate=tree_hparams.learning_rate,
                    max_depth=tree_hparams.max_depth,
                    pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
                return grow_op

            if train_in_memory and is_single_machine:
                # Grow directly from this batch's summaries; no accumulator.
                train_op.append(state_ops.assign_add(global_step, 1))
                train_op.append(
                    grow_tree_from_stats_summaries(stats_summary_list))
            else:
                summary_accumulator = data_flow_ops.ConditionalAccumulator(
                    dtype=dtypes.float32,
                    # The stats consist of gradients and hessians (the last dimension).
                    shape=[num_features, max_splits, num_buckets, 2],
                    shared_name='stats_summary_accumulator')
                # The stacked per-feature summaries are applied together with
                # the ensemble's stamp token.
                apply_grad = summary_accumulator.apply_grad(
                    array_ops.stack(stats_summary_list, axis=0), stamp_token)

                def grow_tree_from_accumulated_summaries_fn():
                    """Updates the tree with the best layer from accumulated summaries."""
                    # Take out the accumulated summaries from the accumulator and grow.
                    stats_summary_list = array_ops.unstack(
                        summary_accumulator.take_grad(1), axis=0)
                    grow_op = grow_tree_from_stats_summaries(
                        stats_summary_list)
                    return grow_op

                with ops.control_dependencies([apply_grad]):
                    train_op.append(state_ops.assign_add(global_step, 1))
                    # Only the chief adds the conditional grow op, which fires
                    # once at least n_batches_per_layer summaries accumulated.
                    if config.is_chief:
                        train_op.append(
                            control_flow_ops.cond(
                                math_ops.greater_equal(
                                    summary_accumulator.num_accumulated(),
                                    n_batches_per_layer),
                                grow_tree_from_accumulated_summaries_fn,
                                control_flow_ops.no_op,
                                name='wait_until_n_batches_accumulated'))

            return control_flow_ops.group(train_op, name='train_op')

    estimator_spec = head.create_estimator_spec(features=features,
                                                mode=mode,
                                                labels=labels,
                                                train_op_fn=_train_op_fn,
                                                logits=logits)
    if mode == model_fn.ModeKeys.TRAIN:
        # Add an early stop hook. num_trees is only assigned in the TRAIN
        # branch above, which is the only mode that reaches this line.
        estimator_spec = estimator_spec._replace(
            training_hooks=estimator_spec.training_hooks +
            (StopAtNumTreesHook(num_trees, tree_hparams.n_trees), ))
    return estimator_spec
예제 #10
0
  def testPostPruningOfAllNodes(self):
    """Checks post-pruning when every grown node ends up being pruned."""
    # Expected ensemble after the first layer: the root is split on feature
    # id 1 (gain -0.62) even though the gain is negative, and the tree is not
    # finalized yet.
    expected_after_first_layer = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.62
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            leaf {
              scalar: 0.0143
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
    # Expected ensemble after the second layer: the finalized first tree has
    # been pruned down to a bare leaf, with the per-node logit changes
    # recorded, and a second tree with a bare leaf has been started.
    expected_after_second_layer = """
      trees {
        nodes {
          leaf {
          }
        }
      }
      trees {
        nodes {
          leaf {
          }
        }
      }
      tree_weights: 1.0
      tree_weights: 1.0
      tree_metadata{
        num_layers_grown: 2
        is_finalized: true
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: 0.0
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.01
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.0143
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.033
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.022343
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.3143
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -24.0143
        }
      }
      tree_metadata {
      }
      growing_metadata {
        num_trees_attempted: 1
        num_layers_attempted: 2
      }
      """
    with self.test_session() as sess:
      # Start from an empty ensemble.
      empty_proto = boosted_trees_pb2.TreeEnsemble()
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=empty_proto.SerializeToString())
      handle = ensemble.resource_handle

      resources.initialize_resources(resources.shared_resources()).run()

      # First layer: two candidate root splits (feature ids 0 and 1), both
      # with negative gain.
      first_node_ids = [
          np.array([0], dtype=np.int32),
          np.array([0], dtype=np.int32),
      ]
      first_gains = [
          np.array([-1.3], dtype=np.float32),
          np.array([-0.62], dtype=np.float32),
      ]
      first_thresholds = [
          np.array([7], dtype=np.int32),
          np.array([33], dtype=np.int32),
      ]
      first_left_contribs = [
          np.array([[0.013]], dtype=np.float32),
          np.array([[0.01]], dtype=np.float32),
      ]
      first_right_contribs = [
          np.array([[0.0143]], dtype=np.float32),
          np.array([[0.0143]], dtype=np.float32),
      ]

      sess.run(
          boosted_trees_ops.update_ensemble(
              handle,
              learning_rate=1.0,
              pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
              max_depth=2,
              feature_ids=[0, 1],
              node_ids=first_node_ids,
              gains=first_gains,
              thresholds=first_thresholds,
              left_node_contribs=first_left_contribs,
              right_node_contribs=first_right_contribs))

      new_stamp, serialized = sess.run(ensemble.serialize())
      grown = boosted_trees_pb2.TreeEnsemble()
      grown.ParseFromString(serialized)

      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_after_first_layer, grown)

      # Second layer: a single feature (id 3) offers candidates for nodes 1
      # and 2, again both with negative gain.
      second_node_ids = [np.array([1, 2], dtype=np.int32)]
      second_gains = [np.array([-0.2, -0.5], dtype=np.float32)]
      second_thresholds = [np.array([77, 79], dtype=np.int32)]
      second_left_contribs = [
          np.array([[0.023], [0.3]], dtype=np.float32)
      ]
      second_right_contribs = [
          np.array([[0.012343], [24]], dtype=np.float32)
      ]

      sess.run(
          boosted_trees_ops.update_ensemble(
              handle,
              learning_rate=1.0,
              pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
              max_depth=2,
              feature_ids=[3],
              node_ids=second_node_ids,
              gains=second_gains,
              thresholds=second_thresholds,
              left_node_contribs=second_left_contribs,
              right_node_contribs=second_right_contribs))

      # The tree reaches max depth here and is finalized; since every split
      # had negative gain, the whole tree is expected to be pruned.
      new_stamp, serialized = sess.run(ensemble.serialize())
      pruned = boosted_trees_pb2.TreeEnsemble()
      pruned.ParseFromString(serialized)

      self.assertEqual(new_stamp, 2)
      self.assertProtoEquals(expected_after_second_layer, pruned)
예제 #11
0
    def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
        """Checks cached predictions taken before tree 0 was post-pruned.

        Every example is cached in tree 0 at a node id that predates
        post-pruning. The test expects each example to end up at the root of
        tree 1, and expects the returned logit updates, once added to the
        stale cached logits, to equal the tree-0 leaf value plus the tree-1
        root value.
        """
        # Tree 0 is finalized and lists post-pruning metadata for the
        # original node ids 0..8; tree 1 is a single leaf that is still
        # growing.
        ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id:0
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 5
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
              scalar: 0.0783
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.55
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 4
        }
      """
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Load the two-tree ensemble into a resource and initialize it.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            num_examples = 6
            cached_tree_ids = [0] * num_examples
            # Old leaves 3, 4, 7 and 8 were removed by post-pruning (their
            # metadata maps them to node 1); old leaves 5 and 6 were
            # renumbered to 3 and 4.
            cached_node_ids = list(range(3, 9))

            # Bucketized values for features 0 and 1, one entry per example.
            feature_0_values = [12, 17, 35, 36, 23, 11]
            feature_1_values = [12, 12, 17, 18, 123, 24]

            # Build the training-time prediction op (nothing is grown here).
            predict_op = boosted_trees_ops.training_predict(
                ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # Every example must now sit in the last tree (tree 1) ...
            self.assertAllClose([1] * num_examples, new_tree_ids)
            # ... and, because tree 1 is a lone leaf, at its root node.
            self.assertAllClose([0] * num_examples, new_node_ids)

            # Logits the examples had cached at their pre-pruning nodes.
            cached_values = [[0.08], [0.093], [0.0553], [0.0783],
                             [0.15 + 0.08], [0.5 + 0.08]]
            # Expected totals: the tree-1 root value plus the tree-0 leaf that
            # each example maps to after pruning (leaf 1 for old leaves
            # 3, 4, 7, 8; leaves 3 and 4 for old leaves 5 and 6).
            root = 0.55
            tree_0_leaf_values = (0.01, 0.01, 0.0553, 0.0783, 0.01, 0.01)
            expected_logits = [[root + leaf] for leaf in tree_0_leaf_values]
            self.assertAllClose(expected_logits,
                                logits_updates + cached_values)
예제 #12
0
  def testPostPruningChangesNothing(self):
    """Grows one layer with post-pruning when every candidate gain is > 0.

    With max_depth=1 the tree is expected to be finalized after this single
    layer, and the expected proto carries no post-pruning metadata.
    """
    with self.test_session() as session:
      # Start from an ensemble that contains no trees.
      empty_config = boosted_trees_pb2.TreeEnsemble()
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=empty_config.SerializeToString())
      ensemble_handle = tree_ensemble.resource_handle

      resources.initialize_resources(resources.shared_resources()).run()

      # Two candidate splits for the root node (node 0), one per feature.
      # Both gains are positive; the candidate on feature 3 has the larger
      # gain (7.62 vs 0.63).
      feature_ids = [3, 4]
      node_ids = [
          np.array([0], dtype=np.int32),
          np.array([0], dtype=np.int32),
      ]
      gains = [
          np.array([7.62], dtype=np.float32),
          np.array([0.63], dtype=np.float32),
      ]
      thresholds = [
          np.array([52], dtype=np.int32),
          np.array([23], dtype=np.int32),
      ]
      left_node_contribs = [
          np.array([[-4.375]], dtype=np.float32),
          np.array([[-0.6]], dtype=np.float32),
      ]
      right_node_contribs = [
          np.array([[7.143]], dtype=np.float32),
          np.array([[0.24]], dtype=np.float32),
      ]

      # Apply the candidates to the ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=1,
          feature_ids=feature_ids,
          node_ids=node_ids,
          gains=gains,
          thresholds=thresholds,
          left_node_contribs=left_node_contribs,
          right_node_contribs=right_node_contribs)

      session.run(grow_op)

      # Read the ensemble back and parse it for comparison.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      # The split on feature 3 should have been chosen. The tree is finalized
      # but nothing was pruned, and an empty second tree follows it.
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(
          """
        trees {
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 52
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
          nodes {
            leaf {
              scalar: 7.143
            }
          }
        }
        trees {
          nodes {
            leaf {
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, res_ensemble)
예제 #13
0
    def testCachedPredictionIsCurrent(self):
        """Checks that cached predictions that are still current are kept.

        Both examples are cached in leaves of the only tree, so the test
        expects unchanged tree and node ids and zero logit updates.
        """
        # A single finalized tree: one root split with leaves 1 and 2.
        ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Load the ensemble into a resource and initialize it.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples, cached in nodes 1 and 2 of tree 0 respectively.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 2]

            # Bucketized values for features 0 and 1. They are not expected to
            # matter because the tree did not change since caching.
            feature_0_values = [67, 5]
            feature_1_values = [9, 17]

            # Build the training-time prediction op (nothing is grown here).
            predict_op = boosted_trees_ops.training_predict(
                ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # The cache was already up to date: same ids, no logit change.
            self.assertAllClose(cached_tree_ids, new_tree_ids)
            self.assertAllClose(cached_node_ids, new_node_ids)
            self.assertAllClose([[0], [0]], logits_updates)
예제 #14
0
    def testContribsMultipleTree(self):
        """Checks per-example debug outputs for an ensemble of three trees.

        For each example the test compares the feature ids and the logits
        path reported by the debug-output op against hand-computed values.
        """
        # Three trees with weights 0.1, 0.2 and 1.0. Only features 0 and 2
        # are used by the splits.
        ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf: {scalar: 2.1}
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
        tree_metadata: {
          num_layers_grown: 1}
        tree_metadata: {
          num_layers_grown: 2}
        tree_metadata: {
          num_layers_grown: 1}
      """
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Load the ensemble into a resource and initialize it.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Bucketized inputs for two examples. Feature 1 is unused because
            # no split in the ensemble above refers to it.
            feature_0_values = [36, 32]
            feature_1_values = [13, -29]
            feature_2_values = [11, 27]
            bucketized_features = [
                feature_0_values, feature_1_values, feature_2_values
            ]

            # The first path entry is the bias: the original leaf of tree 0's
            # root scaled by that tree's weight.
            bias = 2.1 * 0.1
            expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
            # Remaining path entries are running weighted logits:
            #   example 0: 0.1 * 1.14, 0.2 * 5.5 + .114, 0.2 * 5. + .114,
            #              1.0 * 5.0 + 0.2 * 5. + .114
            #   example 1: 0.1 * 1.14, 0.2 * 7 + .114,
            #              1.0 * -7. + 0.2 * 7 + .114
            expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114),
                                     (bias, 0.114, 1.514, -5.486))

            debug_op = boosted_trees_ops.example_debug_outputs(
                ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            serialized_examples_debug_outputs = session.run(debug_op)

            # Decode one DebugOutput proto per example, then pull out the two
            # repeated fields under test.
            parsed_outputs = []
            for serialized_output in serialized_examples_debug_outputs:
                debug_output = boosted_trees_pb2.DebugOutput()
                debug_output.ParseFromString(serialized_output)
                parsed_outputs.append(debug_output)
            feature_ids = [output.feature_ids for output in parsed_outputs]
            logits_paths = [output.logits_path for output in parsed_outputs]

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)
예제 #15
0
  def testPrePruning(self):
    """Grows one more layer of an existing tree with pre-pruning enabled.

    Node 1 only receives negative-gain candidates and is expected to stay a
    leaf; node 2 is expected to be split on its highest-gain candidate.
    """
    # One tree, not yet finalized, consisting of a single root split.
    initial_ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
    with self.test_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(initial_ensemble_text, tree_ensemble_config)

      # Load the existing ensemble into a resource and initialize it.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Three candidate sets, given in order with feature ids 0, 1 and 0:
      #   set 1: node 1 only, gain -1.4.
      #   set 2: node 1 (gain -0.63) and node 2 (gain 2.7).
      #   set 3: node 2 only, gain 2.8.
      # The best candidate for node 1 is set 2, but its gain is negative, so
      # node 1 should not be split. The best candidate for node 2 is set 3,
      # whose gain is positive.
      feature_ids = [0, 1, 0]
      node_ids = [
          np.array([1], dtype=np.int32),
          np.array([1, 2], dtype=np.int32),
          np.array([2], dtype=np.int32),
      ]
      gains = [
          np.array([-1.4], dtype=np.float32),
          np.array([-0.63, 2.7], dtype=np.float32),
          np.array([2.8], dtype=np.float32),
      ]
      thresholds = [
          np.array([21], dtype=np.int32),
          np.array([23, 7], dtype=np.int32),
          np.array([3], dtype=np.int32),
      ]
      left_node_contribs = [
          np.array([[-6.0]], dtype=np.float32),
          np.array([[-0.6], [-1.5]], dtype=np.float32),
          np.array([[-0.75]], dtype=np.float32),
      ]
      right_node_contribs = [
          np.array([[1.65]], dtype=np.float32),
          np.array([[0.24], [2.3]], dtype=np.float32),
          np.array([[1.93]], dtype=np.float32),
      ]

      # Apply the candidates to the ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
          max_depth=3,
          feature_ids=feature_ids,
          node_ids=node_ids,
          gains=gains,
          thresholds=thresholds,
          left_node_contribs=left_node_contribs,
          right_node_contribs=right_node_contribs)
      session.run(grow_op)

      # Read the ensemble back and parse it for comparison.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      # Node 1 is expected to remain a leaf and node 2 to be split with the
      # third candidate set (feature id 0, threshold 3, gain 2.8). The tree
      # should not be finalized: max depth is 3 and only 2 layers are grown.
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(
          """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 3
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 2.8
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: -4.45
            }
          }
          nodes {
            leaf {
              scalar: -4.182
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: false
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, res_ensemble)
예제 #16
0
  def testGrowExistingEnsembleTreeFinalized(self):
    """Grows an ensemble whose first tree is already finalized.

    The stored ensemble ends with a single-leaf second tree; the test expects
    the new split to land on that tree's root.
    """
    # Tree 0 is finalized (one root split); tree 1 is a lone leaf with empty
    # metadata.
    initial_ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.0
            }
          }
        }
        tree_weights: 0.15
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
    with self.test_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(initial_ensemble_text, tree_ensemble_config)

      # Load the existing ensemble into a resource and initialize it.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # A single candidate split on feature 75 for node 0. Its gain is
      # negative, but pruning is disabled for this test.
      feature_ids = [75]
      candidate_nodes = np.array([0], dtype=np.int32)
      candidate_gains = np.array([-1.4], dtype=np.float32)
      candidate_thresholds = np.array([21], dtype=np.int32)
      candidate_left_contribs = np.array([[-6.0]], dtype=np.float32)
      candidate_right_contribs = np.array([[1.65]], dtype=np.float32)

      # Apply the candidate to the ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
          learning_rate=0.1,
          max_depth=2,
          feature_ids=feature_ids,
          node_ids=[candidate_nodes],
          gains=[candidate_gains],
          thresholds=[candidate_thresholds],
          left_node_contribs=[candidate_left_contribs],
          right_node_contribs=[candidate_right_contribs])
      session.run(grow_op)

      # Read the ensemble back and parse it for comparison.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      # Tree 0 must be untouched; tree 1 now has a root split on feature 75
      # whose leaf values are the contributions scaled by the learning rate.
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(
          """
       trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 75
              threshold: 21
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -1.4
            }
          }
          nodes {
            leaf {
              scalar: -0.6
            }
          }
          nodes {
            leaf {
              scalar: 0.165
            }
          }
        }
        tree_weights: 0.15
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 2
        }
      """, res_ensemble)
예제 #17
0
  def testGrowWithEmptyEnsemble(self):
    """Grows the first layer of an ensemble that starts with no trees."""
    with self.test_session() as session:
      # Create an ensemble resource without any serialized content.
      tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
      ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Three candidate splits for the root node (node 0), on features 0, 2
      # and 6. The first and third have close gains (7.62 vs 7.65) but
      # different thresholds and contributions; the third has the highest
      # gain.
      feature_ids = [0, 2, 6]
      node_ids = [
          np.array([0], dtype=np.int32),
          np.array([0], dtype=np.int32),
          np.array([0], dtype=np.int32),
      ]
      gains = [
          np.array([7.62], dtype=np.float32),
          np.array([0.63], dtype=np.float32),
          np.array([7.65], dtype=np.float32),
      ]
      thresholds = [
          np.array([52], dtype=np.int32),
          np.array([23], dtype=np.int32),
          np.array([7], dtype=np.int32),
      ]
      left_node_contribs = [
          np.array([[-4.375]], dtype=np.float32),
          np.array([[-0.6]], dtype=np.float32),
          np.array([[-4.89]], dtype=np.float32),
      ]
      right_node_contribs = [
          np.array([[7.143]], dtype=np.float32),
          np.array([[0.24]], dtype=np.float32),
          np.array([[5.3]], dtype=np.float32),
      ]

      # Apply the candidates. With max_depth=1 the tree reaches its maximum
      # depth after this one layer and is expected to be finalized.
      grow_op = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
          max_depth=1,
          feature_ids=feature_ids,
          node_ids=node_ids,
          gains=gains,
          thresholds=thresholds,
          left_node_contribs=left_node_contribs,
          right_node_contribs=right_node_contribs)
      session.run(grow_op)

      # Read the ensemble back and parse it for comparison.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      # The split on feature 6 should win, with leaf values scaled by the
      # learning rate. Because the tree is finalized, a new dummy tree is
      # expected after it.
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(
          """
        trees {
          nodes {
            bucketized_split {
              feature_id: 6
              threshold: 7
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.65
            }
          }
          nodes {
            leaf {
              scalar: -0.489
            }
          }
          nodes {
            leaf {
              scalar: 0.53
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.0
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, res_ensemble)
예제 #18
0
  def testGrowExistingEnsembleTreeNotFinalized(self):
    """Grows the second layer of an existing tree that is not finalized.

    Both leaves of the existing root split receive candidates; with
    max_depth=2 the tree is expected to be finalized afterwards.
    """
    # One tree, not yet finalized, consisting of a single root split.
    initial_ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 0.714
            }
          }
          nodes {
            leaf {
              scalar: -0.4375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
    with self.test_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(initial_ensemble_text, tree_ensemble_config)

      # Load the existing ensemble into a resource and initialize it.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Three candidate sets, given in order with feature ids 0, 1 and 0:
      #   set 1: node 1 only, gain 1.4.
      #   set 2: node 1 (gain 0.63) and node 2 (gain 2.7).
      #   set 3: node 2 only, gain 1.7.
      feature_ids = [0, 1, 0]
      node_ids = [
          np.array([1], dtype=np.int32),
          np.array([1, 2], dtype=np.int32),
          np.array([2], dtype=np.int32),
      ]
      gains = [
          np.array([1.4], dtype=np.float32),
          np.array([0.63, 2.7], dtype=np.float32),
          np.array([1.7], dtype=np.float32),
      ]
      thresholds = [
          np.array([21], dtype=np.int32),
          np.array([23, 7], dtype=np.int32),
          np.array([3], dtype=np.int32),
      ]
      left_node_contribs = [
          np.array([[-6.0]], dtype=np.float32),
          np.array([[-0.6], [-1.5]], dtype=np.float32),
          np.array([[-0.75]], dtype=np.float32),
      ]
      right_node_contribs = [
          np.array([[1.65]], dtype=np.float32),
          np.array([[0.24], [2.3]], dtype=np.float32),
          np.array([[1.93]], dtype=np.float32),
      ]

      # Apply the candidates. The tree reaches max_depth=2 with this layer,
      # so it is expected to be finalized.
      grow_op = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
          max_depth=2,
          feature_ids=feature_ids,
          node_ids=node_ids,
          gains=gains,
          thresholds=thresholds,
          left_node_contribs=left_node_contribs,
          right_node_contribs=right_node_contribs)
      session.run(grow_op)

      # Read the ensemble back and parse it for comparison.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      # Node 1 should take its split from set 1 (gain 1.4) and node 2 from
      # set 2 (gain 2.7). The tree is finalized after 2 layers, and a new
      # single-leaf tree is expected after it.
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(
          """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            bucketized_split {
              threshold: 21
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 0.714
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -0.4375
              }
            }
          }
          nodes {
            leaf {
              scalar: 0.114
            }
          }
          nodes {
            leaf {
              scalar: 0.879
            }
          }
          nodes {
            leaf {
              scalar: -0.5875
            }
          }
          nodes {
            leaf {
              scalar: -0.2075
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.0
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, res_ensemble)
예제 #19
0
    def testSerializeDeserialize(self):
        """Round-trips an ensemble through deserialize() and serialize().

        An empty ensemble is created with stamp 5, a hand-written proto is
        loaded into it with stamp 3, the reported states are checked, and the
        ensemble is serialized back and compared with the loaded proto.
        """
        with self.cached_session():
            # Fresh resource: only the stamp is set, no trees or layers yet.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble', stamp_token=5)
            resources.initialize_resources(resources.shared_resources()).run()
            (stamp, tree_count, finalized_count, attempted_layers,
             node_range) = tree_ensemble.get_states()
            self.assertEqual(5, self.evaluate(stamp))
            self.assertEqual(0, self.evaluate(tree_count))
            self.assertEqual(0, self.evaluate(finalized_count))
            self.assertEqual(0, self.evaluate(attempted_layers))
            self.assertAllEqual([0, 1], self.evaluate(node_range))

            # Text-format ensemble to load: one unfinalized tree with a root
            # split and two leaves, plus growing metadata.
            loaded_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 75
              threshold: 21
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -1.4
            }
          }
          nodes {
            leaf {
              scalar: -0.6
            }
          }
          nodes {
            leaf {
              scalar: 0.165
            }
          }
        }
        tree_weights: 0.5
        tree_metadata {
          num_layers_grown: 4  # it's fake intentionally.
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 5
          last_layer_node_start: 3
          last_layer_node_end: 7
        }
      """
            loaded_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(loaded_text, loaded_proto)

            # Read the states only after the deserialize op has run.
            deserialize_op = tree_ensemble.deserialize(
                stamp_token=3,
                serialized_proto=loaded_proto.SerializeToString())
            with ops.control_dependencies([deserialize_op]):
                (stamp, tree_count, finalized_count, attempted_layers,
                 node_range) = tree_ensemble.get_states()
            self.assertEqual(3, self.evaluate(stamp))
            self.assertEqual(1, self.evaluate(tree_count))
            # Taken from growing_metadata rather than counted from the tree.
            self.assertEqual(5, self.evaluate(attempted_layers))
            self.assertEqual(0, self.evaluate(finalized_count))
            self.assertAllEqual([3, 7], self.evaluate(node_range))

            # Serializing must give back the stamp and the proto just loaded.
            stamp_after, serialized_after = tree_ensemble.serialize()
            self.assertEqual(3, self.evaluate(stamp_after))
            round_tripped_proto = boosted_trees_pb2.TreeEnsemble()
            round_tripped_proto.ParseFromString(serialized_after.eval())
            self.assertProtoEquals(loaded_proto, round_tripped_proto)
예제 #20
0
  def testMetadataWhenCantSplitDuePrePruning(self):
    """Checks metadata bookkeeping when pre-pruning rejects every split.

    The starting ensemble holds one unfinalized tree with a single root split.
    Every candidate passed to update_ensemble has a negative gain; the expected
    proto keeps the tree and its tree_metadata unchanged while
    growing_metadata.num_layers_attempted goes from 1 to 2.
    """
    with self.test_session() as session:
      initial_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
      initial_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(initial_text, initial_config)

      # Build the ensemble resource from the one-split tree above.
      ensemble_resource = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=initial_config.SerializeToString())
      ensemble_handle = ensemble_resource.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Three candidate groups, tagged with feature ids 0, 1 and 0.
      candidate_feature_ids = [0, 1, 0]

      # Every gain below is negative.
      first_nodes = np.array([1], dtype=np.int32)
      first_gains = np.array([-1.4], dtype=np.float32)
      first_thresholds = np.array([21], dtype=np.int32)
      first_left_contribs = np.array([[-6.0]], dtype=np.float32)
      first_right_contribs = np.array([[1.65]], dtype=np.float32)

      second_nodes = np.array([1, 2], dtype=np.int32)
      second_gains = np.array([-0.63, -2.7], dtype=np.float32)
      second_thresholds = np.array([23, 7], dtype=np.int32)
      second_left_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
      second_right_contribs = np.array([[0.24], [2.3]], dtype=np.float32)

      third_nodes = np.array([2], dtype=np.int32)
      third_gains = np.array([-2.8], dtype=np.float32)
      third_thresholds = np.array([3], dtype=np.int32)
      third_left_contribs = np.array([[-0.75]], dtype=np.float32)
      third_right_contribs = np.array([[1.93]], dtype=np.float32)

      # Per-candidate-group inputs, in the same order as the feature ids.
      all_nodes = [first_nodes, second_nodes, third_nodes]
      all_gains = [first_gains, second_gains, third_gains]
      all_thresholds = [first_thresholds, second_thresholds, third_thresholds]
      all_left_contribs = [
          first_left_contribs, second_left_contribs, third_left_contribs
      ]
      all_right_contribs = [
          first_right_contribs, second_right_contribs, third_right_contribs
      ]

      # Attempt to grow one more layer with pre-pruning enabled.
      grow_op = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
          max_depth=3,
          feature_ids=candidate_feature_ids,
          node_ids=all_nodes,
          gains=all_gains,
          thresholds=all_thresholds,
          left_node_contribs=all_left_contribs,
          right_node_contribs=all_right_contribs)
      session.run(grow_op)

      # No split is expected since all gains were negative: the tree and its
      # tree_metadata stay as they were, only growing_metadata advances.
      new_stamp, serialized = session.run(ensemble_resource.serialize())
      grown_ensemble = boosted_trees_pb2.TreeEnsemble()
      grown_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, grown_ensemble)
예제 #21
0
 def testCreateWithProto(self):
     """Creates an ensemble resource from a serialized proto and reads its states.

     The proto holds two trees (one finalized, one not) and growing metadata;
     get_states() is expected to report the stamp passed at construction along
     with the counts and node range recorded in that proto.
     """
     with self.cached_session():
         source_text = """
     trees {
       nodes {
         bucketized_split {
           feature_id: 4
           left_id: 1
           right_id: 2
         }
         metadata {
           gain: 7.62
         }
       }
       nodes {
         bucketized_split {
           threshold: 21
           left_id: 3
           right_id: 4
         }
         metadata {
           gain: 1.4
           original_leaf {
             scalar: 7.14
           }
         }
       }
       nodes {
         bucketized_split {
           feature_id: 1
           threshold: 7
           left_id: 5
           right_id: 6
         }
         metadata {
           gain: 2.7
           original_leaf {
             scalar: -4.375
           }
         }
       }
       nodes {
         leaf {
           scalar: 6.54
         }
       }
       nodes {
         leaf {
           scalar: 7.305
         }
       }
       nodes {
         leaf {
           scalar: -4.525
         }
       }
       nodes {
         leaf {
           scalar: -4.145
         }
       }
     }
     trees {
       nodes {
         bucketized_split {
           feature_id: 75
           threshold: 21
           left_id: 1
           right_id: 2
         }
         metadata {
           gain: -1.4
         }
       }
       nodes {
         leaf {
           scalar: -0.6
         }
       }
       nodes {
         leaf {
           scalar: 0.165
         }
       }
     }
     tree_weights: 0.15
     tree_weights: 1.0
     tree_metadata {
       num_layers_grown: 2
       is_finalized: true
     }
     tree_metadata {
       num_layers_grown: 1
       is_finalized: false
     }
     growing_metadata {
       num_trees_attempted: 2
       num_layers_attempted: 6
       last_layer_node_start: 16
       last_layer_node_end: 19
     }
   """
         source_proto = boosted_trees_pb2.TreeEnsemble()
         text_format.Merge(source_text, source_proto)

         # Build the resource directly from the serialized proto, stamp 7.
         tree_ensemble = boosted_trees_ops.TreeEnsemble(
             'ensemble',
             stamp_token=7,
             serialized_proto=source_proto.SerializeToString())
         resources.initialize_resources(resources.shared_resources()).run()

         (stamp, tree_count, finalized_count, attempted_layers,
          node_range) = tree_ensemble.get_states()
         self.assertEqual(7, self.evaluate(stamp))
         self.assertEqual(2, self.evaluate(tree_count))
         # Only the first tree is marked is_finalized in the proto.
         self.assertEqual(1, self.evaluate(finalized_count))
         self.assertEqual(6, self.evaluate(attempted_layers))
         self.assertAllEqual([16, 19], self.evaluate(node_range))
예제 #22
0
  def testPostPruningOfSomeNodes(self):
    """Grows one tree over three layers with post-pruning and checks each step.

    Layers one and two are added without any pruning even though some gains
    are negative. The third layer reaches max_depth=3; the expected proto then
    shows the tree finalized, part of it pruned away, post_pruned_nodes_meta
    filled in, and a new empty tree appended.
    """
    with self.test_session() as session:
      # Start from an empty ensemble.
      empty_config = boosted_trees_pb2.TreeEnsemble()
      ensemble_resource = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=empty_config.SerializeToString())
      ensemble_handle = ensemble_resource.resource_handle

      resources.initialize_resources(resources.shared_resources()).run()

      # ---- Layer 1 ----
      # Two candidates for the root; the second has the larger gain, though
      # both gains are negative.
      layer1_feature_ids = [0, 1]

      root_a_nodes = np.array([0], dtype=np.int32)
      root_a_gains = np.array([-1.3], dtype=np.float32)
      root_a_thresholds = np.array([7], dtype=np.int32)
      root_a_left_contribs = np.array([[0.013]], dtype=np.float32)
      root_a_right_contribs = np.array([[0.0143]], dtype=np.float32)

      root_b_nodes = np.array([0], dtype=np.int32)
      root_b_gains = np.array([-0.2], dtype=np.float32)
      root_b_thresholds = np.array([33], dtype=np.int32)
      root_b_left_contribs = np.array([[0.01]], dtype=np.float32)
      root_b_right_contribs = np.array([[0.0143]], dtype=np.float32)

      grow_layer1 = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=3,
          feature_ids=layer1_feature_ids,
          node_ids=[root_a_nodes, root_b_nodes],
          gains=[root_a_gains, root_b_gains],
          thresholds=[root_a_thresholds, root_b_thresholds],
          left_node_contribs=[root_a_left_contribs, root_b_left_contribs],
          right_node_contribs=[root_a_right_contribs, root_b_right_contribs])

      session.run(grow_layer1)

      # The split from the second candidate is expected despite its negative
      # gain; nothing is pruned at this point.
      new_stamp, serialized = session.run(ensemble_resource.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_after_layer1 = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            leaf {
              scalar: 0.0143
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_after_layer1, res_ensemble)

      # ---- Layer 2 ----
      # One candidate group covering nodes 1 and 2: the gain for node 1 is
      # negative, the gain for node 2 is positive.
      layer2_feature_ids = [3]
      layer2_nodes = np.array([1, 2], dtype=np.int32)
      layer2_gains = np.array([-0.2, 0.5], dtype=np.float32)
      layer2_thresholds = np.array([7, 5], dtype=np.int32)
      layer2_left_contribs = np.array(
          [[0.07], [0.041]], dtype=np.float32)
      layer2_right_contribs = np.array(
          [[0.083], [0.064]], dtype=np.float32)

      grow_layer2 = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=3,
          feature_ids=layer2_feature_ids,
          node_ids=[layer2_nodes],
          gains=[layer2_gains],
          thresholds=[layer2_thresholds],
          left_node_contribs=[layer2_left_contribs],
          right_node_contribs=[layer2_right_contribs])

      session.run(grow_layer2)

      # Two layers grown; the tree is still not finalized.
      new_stamp, serialized = session.run(ensemble_resource.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)
      expected_after_layer2 = """
        trees {
          nodes {
            bucketized_split {
              feature_id:1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: -0.2
              original_leaf {
                scalar: 0.01
               }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 5
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.08
            }
          }
          nodes {
            leaf {
              scalar: 0.093
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
                scalar: 0.0783
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 2
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
       """
      self.assertEqual(new_stamp, 2)

      self.assertProtoEquals(expected_after_layer2, res_ensemble)

      # ---- Layer 3 ----
      # Split leaf 3, again with a negative gain. This layer reaches max_depth,
      # so the tree gets finalized and post-pruning runs; leaves 3, 4, 7 and 8
      # are expected to be pruned out.
      layer3_feature_ids = [92]
      layer3_nodes = np.array([3], dtype=np.int32)
      layer3_gains = np.array([-0.45], dtype=np.float32)
      layer3_thresholds = np.array([11], dtype=np.int32)
      layer3_left_contribs = np.array([[0.15]], dtype=np.float32)
      layer3_right_contribs = np.array([[0.5]], dtype=np.float32)

      grow_layer3 = boosted_trees_ops.update_ensemble(
          ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=3,
          feature_ids=layer3_feature_ids,
          node_ids=[layer3_nodes],
          gains=[layer3_gains],
          thresholds=[layer3_thresholds],
          left_node_contribs=[layer3_left_contribs],
          right_node_contribs=[layer3_right_contribs])

      session.run(grow_layer3)
      # The tree is now expected to be finalized.
      new_stamp, serialized = session.run(ensemble_resource.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)
      # Note that nodes 3, 4, 7 and 8 got deleted, so the metadata maps their
      # ids to their ancestor node 1, with the respective change in logits.
      expected_after_layer3 = """
        trees {
          nodes {
            bucketized_split {
              feature_id:1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 5
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
              scalar: 0.0783
            }
          }
        }
        trees {
          nodes {
            leaf {
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 3
        }
       """
      self.assertEqual(new_stamp, 3)
      self.assertProtoEquals(expected_after_layer3, res_ensemble)
예제 #23
0
    def testCachedPredictionFromTheSameTree(self):
        """Checks training_predict when the cache points inside the same tree.

        Both examples are cached at inner nodes of the only tree, so the
        expected logit updates are the differences between the leaf finally
        reached and the cached node's original_leaf, scaled by the tree weight.
        """
        with self.test_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 7.14
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
          nodes {
            leaf {
              scalar: -5.875
            }
          }
          nodes {
            leaf {
              scalar: -2.075
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
            ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, ensemble_config)

            # Build the ensemble resource from the two-layer tree above.
            ensemble_resource = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_config.SerializeToString())
            ensemble_handle = ensemble_resource.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Cache state: example one sits in node 1, example two in node 0,
            # both in tree 0.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # Bucketized values of features 0 and 1, one entry per example.
            feature_0_values = [67, 5]
            feature_1_values = [9, 17]
            bucketized = [feature_0_values, feature_1_values]

            # Continue the prediction from the cached positions.
            predict_op = boosted_trees_ops.training_predict(
                ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=bucketized,
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # Still the same (only) tree.
            self.assertAllClose([0, 0], new_tree_ids)
            # Walking the rest of the tree puts the first example in node 4
            # and the second in node 5.
            self.assertAllClose([4, 5], new_node_ids)
            # Those leaves hold 8.79 and -5.875; relative to the cached nodes'
            # original leaves (7.14 and -2) the differences are 1.65 and
            # -3.875, each then scaled by the tree weight 0.1.
            self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logits_updates)
예제 #24
0
  def testCategoricalSplits(self):
    """Checks training_predict on a tree built from categorical splits."""
    with self.cached_session() as session:
      ensemble_text = """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
        }
      """
      ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(ensemble_text, ensemble_config)

      # Build the ensemble resource from the categorical tree above.
      ensemble_resource = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_config.SerializeToString())
      ensemble_handle = ensemble_resource.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Three examples; one list per feature.
      feature_0_values = [13, 1, 3]
      feature_1_values = [2, 2, 1]
      bucketized = [feature_0_values, feature_1_values]

      # Nothing cached yet: every example starts at the root of tree 0.
      cached_tree_ids = [0, 0, 0]
      cached_node_ids = [0, 0, 0]

      predict_op = boosted_trees_ops.training_predict(
          ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=bucketized,
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      self.assertAllClose([0, 0, 0], new_tree_ids)
      # Example 1 matches both split values and lands in node 3; example 2
      # matches only the root's value and lands in node 4; example 3 misses
      # the root's value and lands in node 2.
      self.assertAllClose([3, 4, 2], new_node_ids)
      self.assertAllClose([[5.], [6.], [7.]], logits_updates)
예제 #25
0
    def testCachedPredictionFromPreviousTree(self):
        """Checks training_predict when the cache points into an earlier tree.

        The ensemble has three trees and both examples are cached in tree 0,
        so the prediction has to continue through the remaining trees.
        """
        with self.test_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7
            }
          }
          nodes {
            leaf {
              scalar: 5
            }
          }
          nodes {
            leaf {
              scalar: 6
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: false
        }
        tree_weights: 0.1
        tree_weights: 0.1
        tree_weights: 0.1
      """
            ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, ensemble_config)

            # Build the ensemble resource from the three trees above.
            ensemble_resource = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_config.SerializeToString())
            ensemble_handle = ensemble_resource.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Cache state: example one sits in node 1 of tree 0, example two
            # in node 0 (the root) of tree 0.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # Bucketized values of features 0 and 1, one entry per example.
            feature_0_values = [36, 32]
            feature_1_values = [11, 27]
            bucketized = [feature_0_values, feature_1_values]

            # Continue the prediction from the cached positions.
            predict_op = boosted_trees_ops.training_predict(
                ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=bucketized,
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)
            # Example 1 reaches node 3 in tree 1 and node 2 in tree 2.
            # Example 2 reaches node 2 in tree 1 and node 1 in tree 2.

            # Both examples finish in the last tree (index 2).
            self.assertAllClose([2, 2], new_tree_ids)
            # Node ids within that last tree.
            self.assertAllClose([2, 1], new_node_ids)
            # Example 1: already at leaf 1.14 in tree 0 (cached), then
            #            tree 1: 5.0, tree 2: 5.0 => change = 0.1*(5.0+5.0)
            # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7.0 =>
            #            change = 0.1*(1.14+7.0-7.0)
            self.assertAllClose([[1], [0.114]], logits_updates)
예제 #26
0
  def testCategoricalSplits(self):
    """Checks predict() on a tree built from categorical splits."""
    with self.cached_session() as session:
      ensemble_text = """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
      """
      ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(ensemble_text, ensemble_config)

      # Build the ensemble resource from the categorical tree above.
      ensemble_resource = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_config.SerializeToString())
      ensemble_handle = ensemble_resource.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Three examples; one list per feature.
      feature_0_values = [13, 1, 3]
      feature_1_values = [2, 2, 1]
      bucketized = [feature_0_values, feature_1_values]

      # One logit per example: the values of leaves 3, 4 and 2 respectively.
      expected_logits = [[5.], [6.], [7.]]

      predict_op = boosted_trees_ops.predict(
          ensemble_handle,
          bucketized_features=bucketized,
          logits_dimension=1)

      logits = session.run(predict_op)
      self.assertAllClose(expected_logits, logits)
예제 #27
0
    def testNoCachedPredictionButTreeExists(self):
        """Checks training_predict with an existing tree and nothing cached.

        Both examples start at the root of the only tree, so the expected
        updates are the reached leaf values scaled by the tree weight.
        """
        with self.test_session() as session:
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
            ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, ensemble_config)

            # Build the ensemble resource from the single-split tree above.
            ensemble_resource = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_config.SerializeToString())
            ensemble_handle = ensemble_resource.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Neither example has been cached: both start at tree 0, node 0.
            cached_tree_ids = [0, 0]
            cached_node_ids = [0, 0]

            # Bucketized values of the single feature, one entry per example.
            feature_0_values = [67, 5]

            # Run the prediction from the root.
            predict_op = boosted_trees_ops.training_predict(
                ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # Both examples stay in the first tree.
            self.assertAllClose([0, 0], new_tree_ids)
            # 67 is above the threshold 15 and lands in node 2; 5 is below it
            # and lands in node 1.
            self.assertAllClose([2, 1], new_node_ids)
            self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logits_updates)
예제 #28
0
def _bt_model_fn(
    features,
    labels,
    mode,
    head,
    feature_columns,
    tree_hparams,
    n_batches_per_layer,
    config,
    closed_form_grad_and_hess_fn=None,
    example_id_column_name=None,
    # TODO(youngheek): replace this later using other options.
    train_in_memory=False,
    name='boosted_trees'):
  """Gradient Boosted Trees model_fn.

  Builds the prediction graph for every mode and, in TRAIN mode, the graph
  that accumulates per-feature gradient/hessian statistics and grows the tree
  ensemble by one layer once enough batches have been accumulated.

  Args:
    features: dict of `Tensor`.
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
      dtype `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    head: A `head_lib._Head` instance.
    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
      They are processed in order of their `name`.
    tree_hparams: Hyper parameter container. This function reads the
      attributes `n_trees`, `max_depth`, `learning_rate`, `l1`, `l2`,
      `tree_complexity` and `min_node_weight` from it.
    n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
      least n_batches_per_layer accumulations.
    config: `RunConfig` object to configure the runtime settings.
    closed_form_grad_and_hess_fn: a function that accepts logits and labels
      and returns gradients and hessians. By default, they are created by
      tf.gradients() from the loss.
    example_id_column_name: Name of the feature for a unique ID per example.
      Currently experimental -- not exposed to public API.
    train_in_memory: `bool`, when true, it assumes the dataset is in memory,
      i.e., input_fn should return the entire dataset as a single batch, and
      also n_batches_per_layer should be set as 1.
    name: Name to use for the model.

  Returns:
      An `EstimatorSpec` instance.

  Raises:
    AssertionError: if `train_in_memory` is set and `n_batches_per_layer` is
      not 1.
    ValueError: if `train_in_memory` is set for a non-chief worker or for a
      configuration with more than one worker replica or any ps replica.
  """
  is_single_machine = (config.num_worker_replicas <= 1)

  # Sorting by name gives every worker the same feature ordering.
  sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
  if train_in_memory:
    assert n_batches_per_layer == 1, (
        'When train_in_memory is enabled, input_fn should return the entire '
        'dataset as a single batch, and n_batches_per_layer should be set as '
        '1.')
    if (not config.is_chief or config.num_worker_replicas > 1 or
        config.num_ps_replicas > 0):
      raise ValueError('train_in_memory is supported only for '
                       'non-distributed training.')
  worker_device = control_flow_ops.no_op().device
  # Upper bound on the number of splits in the whole tree, computed as
  # 2^D - 1 for D = tree_hparams.max_depth.
  # TODO(youngheek): perhaps storage could be optimized by storing stats with
  # the dimension max_splits_per_layer, instead of max_splits (for the entire
  # tree).
  max_splits = (1 << tree_hparams.max_depth) - 1
  # Ops collected here are grouped into the final train op by _train_op_fn.
  train_op = []
  with ops.name_scope(name) as name:
    # Prepare.
    global_step = training_util.get_or_create_global_step()
    bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
        sorted_feature_columns)
    # Extract input features and set up cache for training.
    training_state_cache = None
    if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
      # cache transformed features as well for in-memory training.
      batch_size = array_ops.shape(labels)[0]
      input_feature_list, input_cache_op = (
          _cache_transformed_features(features, sorted_feature_columns,
                                      batch_size))
      train_op.append(input_cache_op)
      training_state_cache = _CacheTrainingStatesUsingVariables(
          batch_size, head.logits_dimension)
    else:
      input_feature_list = _get_transformed_features(features,
                                                     sorted_feature_columns)
      if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
        example_ids = features[example_id_column_name]
        training_state_cache = _CacheTrainingStatesUsingHashTable(
            example_ids, head.logits_dimension)

    # Create Ensemble resources.
    tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
    # Create logits.
    if mode != model_fn.ModeKeys.TRAIN:
      logits = boosted_trees_ops.predict(
          # For non-TRAIN mode, ensemble doesn't change after initialization,
          # so no local copy is needed; using tree_ensemble directly.
          tree_ensemble_handle=tree_ensemble.resource_handle,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
    else:
      if is_single_machine:
        local_tree_ensemble = tree_ensemble
        ensemble_reload = control_flow_ops.no_op()
      else:
        # Have a local copy of ensemble for the distributed setting.
        with ops.device(worker_device):
          local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
              name=name + '_local', is_local=True)
        # TODO(soroush): Do partial updates if this becomes a bottleneck.
        ensemble_reload = local_tree_ensemble.deserialize(
            *tree_ensemble.serialize())
      if training_state_cache:
        cached_tree_ids, cached_node_ids, cached_logits = (
            training_state_cache.lookup())
      else:
        # Always start from the beginning when no cache is set up.
        batch_size = array_ops.shape(labels)[0]
        cached_tree_ids, cached_node_ids, cached_logits = (
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            array_ops.zeros(
                [batch_size, head.logits_dimension], dtype=dtypes.float32))
      # Read the ensemble state and predict only after the (possibly no-op)
      # reload of the local ensemble copy.
      with ops.control_dependencies([ensemble_reload]):
        (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
         last_layer_nodes_range) = local_tree_ensemble.get_states()
        summary.scalar('ensemble/num_trees', num_trees)
        summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
        summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

        partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
            tree_ensemble_handle=local_tree_ensemble.resource_handle,
            cached_tree_ids=cached_tree_ids,
            cached_node_ids=cached_node_ids,
            bucketized_features=input_feature_list,
            logits_dimension=head.logits_dimension)
      # training_predict returns only the update relative to the cached state.
      logits = cached_logits + partial_logits

    # Create training graph.
    def _train_op_fn(loss):
      """Run one training iteration."""
      if training_state_cache:
        train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
      if closed_form_grad_and_hess_fn:
        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
      else:
        gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
        hessians = gradients_impl.gradients(
            gradients, logits, name='Hessians')[0]

      # One list of per-feature stats summaries for each bucket-size group.
      stats_summaries_list = []
      for i, feature_ids in enumerate(feature_ids_list):
        num_buckets = bucket_size_list[i]
        summaries = [
            array_ops.squeeze(
                boosted_trees_ops.make_stats_summary(
                    node_ids=node_ids,
                    gradients=gradients,
                    hessians=hessians,
                    bucketized_features_list=[input_feature_list[f]],
                    max_splits=max_splits,
                    num_buckets=num_buckets),
                axis=0) for f in feature_ids
        ]
        stats_summaries_list.append(summaries)

      accumulators = []

      def grow_tree_from_stats_summaries(stats_summaries_list,
                                         feature_ids_list):
        """Updates ensemble based on the best gains from stats summaries."""
        node_ids_per_feature = []
        gains_list = []
        thresholds_list = []
        left_node_contribs_list = []
        right_node_contribs_list = []
        all_feature_ids = []

        assert len(stats_summaries_list) == len(feature_ids_list)

        # Compute split candidates per bucket-size group, then concatenate
        # them so that a single update_ensemble call sees every feature.
        for i, feature_ids in enumerate(feature_ids_list):
          (numeric_node_ids_per_feature, numeric_gains_list,
           numeric_thresholds_list, numeric_left_node_contribs_list,
           numeric_right_node_contribs_list) = (
               boosted_trees_ops.calculate_best_gains_per_feature(
                   node_id_range=last_layer_nodes_range,
                   stats_summary_list=stats_summaries_list[i],
                   l1=tree_hparams.l1,
                   l2=tree_hparams.l2,
                   tree_complexity=tree_hparams.tree_complexity,
                   min_node_weight=tree_hparams.min_node_weight,
                   max_splits=max_splits))

          all_feature_ids += feature_ids
          node_ids_per_feature += numeric_node_ids_per_feature
          gains_list += numeric_gains_list
          thresholds_list += numeric_thresholds_list
          left_node_contribs_list += numeric_left_node_contribs_list
          right_node_contribs_list += numeric_right_node_contribs_list

        grow_op = boosted_trees_ops.update_ensemble(
            # Confirm if local_tree_ensemble or tree_ensemble should be used.
            tree_ensemble.resource_handle,
            feature_ids=all_feature_ids,
            node_ids=node_ids_per_feature,
            gains=gains_list,
            thresholds=thresholds_list,
            left_node_contribs=left_node_contribs_list,
            right_node_contribs=right_node_contribs_list,
            learning_rate=tree_hparams.learning_rate,
            max_depth=tree_hparams.max_depth,
            pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
        return grow_op

      if train_in_memory and is_single_machine:
        # The single batch is the whole dataset, so grow directly from this
        # step's summaries without going through accumulators.
        train_op.append(distribute_lib.increment_var(global_step))
        train_op.append(
            grow_tree_from_stats_summaries(stats_summaries_list,
                                           feature_ids_list))
      else:
        dependencies = []

        for i, feature_ids in enumerate(feature_ids_list):
          stats_summaries = stats_summaries_list[i]
          accumulator = data_flow_ops.ConditionalAccumulator(
              dtype=dtypes.float32,
              # The stats consist of grads and hessians (the last dimension).
              shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
              shared_name='numeric_stats_summary_accumulator_' + str(i))
          accumulators.append(accumulator)

          apply_grad = accumulator.apply_grad(
              array_ops.stack(stats_summaries, axis=0), stamp_token)
          dependencies.append(apply_grad)

        def grow_tree_from_accumulated_summaries_fn():
          """Updates the tree with the best layer from accumulated summaries."""
          # Take out the accumulated summaries from the accumulator and grow.
          stats_summaries_list = [
              array_ops.unstack(accumulator.take_grad(1), axis=0)
              for accumulator in accumulators
          ]

          grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
                                                   feature_ids_list)
          return grow_op

        # The step counter and the growth check run only after this batch's
        # stats have been applied to every accumulator.
        with ops.control_dependencies(dependencies):
          train_op.append(distribute_lib.increment_var(global_step))
          if config.is_chief:
            min_accumulated = math_ops.reduce_min(
                array_ops.stack(
                    [acc.num_accumulated() for acc in accumulators]))

            train_op.append(
                control_flow_ops.cond(
                    math_ops.greater_equal(min_accumulated,
                                           n_batches_per_layer),
                    grow_tree_from_accumulated_summaries_fn,
                    control_flow_ops.no_op,
                    name='wait_until_n_batches_accumulated'))

      return control_flow_ops.group(train_op, name='train_op')

  estimator_spec = head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)
  if mode == model_fn.ModeKeys.TRAIN:
    # Add an early stop hook.
    estimator_spec = estimator_spec._replace(
        training_hooks=estimator_spec.training_hooks +
        (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
                             tree_hparams.n_trees, tree_hparams.max_depth),))
  return estimator_spec
예제 #29
0
    def testPredictionMultipleTree(self):
        """Checks `predict` logits on an ensemble made of three trees."""
        with self.test_session() as session:
            # Three trees weighted 0.1, 0.2 and 1.0 respectively.
            ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            # Load the three-tree ensemble into a resource and initialize it.
            serialized_ensemble = tree_ensemble_config.SerializeToString()
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble', serialized_proto=serialized_ensemble)
            ensemble_handle = ensemble.resource_handle
            init_op = resources.initialize_resources(
                resources.shared_resources())
            init_op.run()

            # One inner list per feature, one entry per example.
            bucketized_features = [
                [36, 32],  # feature 0
                [11, 27],  # feature 1
            ]

            # Example 1 leaves: tree 0 -> 1.14, tree 1 -> 5.0, tree 2 -> 5.0
            #   logit = 0.1 * 1.14 + 0.2 * 5.0 + 1.0 * 5.0 = 6.114
            # Example 2 leaves: tree 0 -> 1.14, tree 1 -> 7.0, tree 2 -> -7.0
            #   logit = 0.1 * 1.14 + 0.2 * 7.0 - 1.0 * 7.0 = -5.486
            expected_logits = [[6.114], [-5.486]]

            # The prediction is requested twice through separately built ops
            # (constructed identically); each must give the expected logits.
            for _ in range(2):
                predict_op = boosted_trees_ops.predict(
                    ensemble_handle,
                    bucketized_features=bucketized_features,
                    logits_dimension=1)
                self.assertAllClose(expected_logits, session.run(predict_op))
예제 #30
0
    def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
        """Checks debug outputs when the first tree is a single bias leaf."""
        with self.test_session() as session:
            # Tree 0 is a lone leaf (the bias); tree 1 has two split levels.
            ensemble_text = """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
        tree_metadata: {
          num_layers_grown: 1
        }
      """
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, tree_ensemble_config)

            serialized_ensemble = tree_ensemble_config.SerializeToString()
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble', serialized_proto=serialized_ensemble)
            ensemble_handle = ensemble.resource_handle
            init_op = resources.initialize_resources(
                resources.shared_resources())
            init_op.run()

            feature_0_values = [36, 32]
            feature_1_values = [13, -29]  # Unused feature.
            feature_2_values = [11, 27]

            # Features split on along each example's path through tree 1.
            expected_feature_ids = ((2, 0), (2, ))
            # Each path starts at the bias (1.72 * 1., the root of tree 0) and
            # then lists bias + 0.1 * <node value> for every node visited in
            # tree 1:
            #   example 0: (bias, bias + 0.1 * 5.5, bias + 0.1 * 5.)
            #   example 1: (bias, bias + 0.1 * 7.)
            expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))

            debug_op = boosted_trees_ops.example_debug_outputs(
                ensemble_handle,
                bucketized_features=[
                    feature_0_values, feature_1_values, feature_2_values
                ],
                logits_dimension=1)

            # Each element returned by the op is one serialized DebugOutput.
            parsed_outputs = []
            for serialized in session.run(debug_op):
                debug_output = boosted_trees_pb2.DebugOutput()
                debug_output.ParseFromString(serialized)
                parsed_outputs.append(debug_output)

            feature_ids = [out.feature_ids for out in parsed_outputs]
            logits_paths = [out.logits_path for out in parsed_outputs]

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)