def _get_train_op_and_ensemble(self, head, config, is_classification,
                                 train_in_memory):
    """Calls bt_model_fn() and returns the train_op and ensemble_serialzed."""
    features, labels = _make_train_input_fn(is_classification)()
    estimator_spec = boosted_trees._bt_model_fn(  # pylint:disable=protected-access
        features=features,
        labels=labels,
        mode=model_fn.ModeKeys.TRAIN,
        head=head,
        feature_columns=self._feature_columns,
        tree_hparams=self._tree_hparams,
        example_id_column_name=EXAMPLE_ID_COLUMN,
        n_batches_per_layer=1,
        config=config,
        train_in_memory=train_in_memory)
    resources.initialize_resources(resources.shared_resources()).run()
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()

    # Gets the train_op and serialized proto of the ensemble.
    shared_resources = resources.shared_resources()
    self.assertEqual(1, len(shared_resources))
    train_op = estimator_spec.train_op
    with ops.control_dependencies([train_op]):
      _, ensemble_serialized = (
          gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
              shared_resources[0].handle))
    return train_op, ensemble_serialized

  def testBiasEnsembleMultiClass(self):
    with self.test_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree = tree_ensemble_config.trees.add()
      tree_ensemble_config.tree_metadata.add().is_finalized = True
      leaf = tree.nodes.add().leaf
      _append_to_leaf(leaf, 0, -0.4)
      _append_to_leaf(leaf, 1, 0.9)

      tree_ensemble_config.tree_weights.append(1.0)

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="multiclass")
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare learner config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 3

      result, dropout_info = self._get_predictions(
          tree_ensemble_handle,
          learner_config=learner_config.SerializeToString(),
          reduce_dim=True)
      self.assertAllClose([[-0.4, 0.9], [-0.4, 0.9]], result.eval())

      # Empty dropout.
      self.assertAllEqual([[], []], dropout_info.eval())

  def testTreeFinalized(self):
    with self.test_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      # Depth 3 tree.
      tree1 = tree_ensemble_config.trees.add()
      _set_float_split(tree1.nodes.add().dense_float_binary_split, 0, 9.0, 1, 2)
      _set_float_split(tree1.nodes.add()
                       .sparse_float_binary_split_default_left.split, 0, -20.0,
                       3, 4)
      _append_to_leaf(tree1.nodes.add().leaf, 0, 0.2)
      _append_to_leaf(tree1.nodes.add().leaf, 0, 0.3)
      _set_categorical_id_split(tree1.nodes.add().categorical_id_binary_split,
                                0, 9, 5, 6)
      _append_to_leaf(tree1.nodes.add().leaf, 0, 0.5)
      _append_to_leaf(tree1.nodes.add().leaf, 0, 0.6)

      tree_ensemble_config.tree_weights.append(1.0)
      tree_ensemble_config.tree_metadata.add().is_finalized = True

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="full_ensemble")
      resources.initialize_resources(resources.shared_resources()).run()

      result = prediction_ops.gradient_trees_partition_examples(
          tree_ensemble_handle, [self._dense_float_tensor], [
              self._sparse_float_indices1, self._sparse_float_indices2
          ], [self._sparse_float_values1, self._sparse_float_values2],
          [self._sparse_float_shape1,
           self._sparse_float_shape2], [self._sparse_int_indices1],
          [self._sparse_int_values1], [self._sparse_int_shape1])

      self.assertAllEqual([0, 0], result.eval())

  def testBasicQuantileBucketsMultipleResources(self):
    with self.test_session() as sess:
      quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
                                                           self.max_elements)
      quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
                                                           self.max_elements)
      resources.initialize_resources(resources.shared_resources()).run()
      summaries = boosted_trees_ops.make_quantile_summaries(
          [self._feature_0, self._feature_1], self._example_weights,
          epsilon=self.eps)
      summary_op_0 = boosted_trees_ops.quantile_add_summaries(
          quantile_accumulator_handle_0,
          [summaries[0]])
      summary_op_1 = boosted_trees_ops.quantile_add_summaries(
          quantile_accumulator_handle_1,
          [summaries[1]])
      flush_op_0 = boosted_trees_ops.quantile_flush(
          quantile_accumulator_handle_0, self.num_quantiles)
      flush_op_1 = boosted_trees_ops.quantile_flush(
          quantile_accumulator_handle_1, self.num_quantiles)
      bucket_0 = boosted_trees_ops.get_bucket_boundaries(
          quantile_accumulator_handle_0, num_features=1)
      bucket_1 = boosted_trees_ops.get_bucket_boundaries(
          quantile_accumulator_handle_1, num_features=1)
      quantiles = boosted_trees_ops.boosted_trees_bucketize(
          [self._feature_0, self._feature_1], bucket_0 + bucket_1)
      sess.run([summary_op_0, summary_op_1])
      sess.run([flush_op_0, flush_op_1])
      self.assertAllClose(self._feature_0_boundaries, bucket_0[0].eval())
      self.assertAllClose(self._feature_1_boundaries, bucket_1[0].eval())

      self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
      self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())

  def testAverageMoreThanNumTreesExist(self):
    with self.test_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      adjusted_tree_ensemble_config = (
          tree_config_pb2.DecisionTreeEnsembleConfig())
      # When we ask to average over more trees than the ensemble contains,
      # averaging is done across all trees.
      total_num = 100
      for i in range(0, total_num):
        tree = tree_ensemble_config.trees.add()
        _append_to_leaf(tree.nodes.add().leaf, 0, -0.4)

        tree_ensemble_config.tree_metadata.add().is_finalized = True
        tree_ensemble_config.tree_weights.append(1.0)
        # This is how the weights will look after averaging.
        copy_tree = adjusted_tree_ensemble_config.trees.add()
        _append_to_leaf(copy_tree.nodes.add().leaf, 0, -0.4)

        adjusted_tree_ensemble_config.tree_metadata.add().is_finalized = True
        adjusted_tree_ensemble_config.tree_weights.append(
            1.0 * (total_num - i) / total_num)

      # Prepare learner config WITH AVERAGING.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 2
      # We have only 100 trees but we ask to average over 250.
      learner_config.averaging_config.average_last_n_trees = 250

      # No averaging config.
      learner_config_no_averaging = learner_pb2.LearnerConfig()
      learner_config_no_averaging.num_classes = 2

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="existing")

      # This is how our ensemble will "look" during averaging
      adjusted_tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=adjusted_tree_ensemble_config.SerializeToString(
          ),
          name="adjusted")

      resources.initialize_resources(resources.shared_resources()).run()

      result, dropout_info = self._get_predictions(
          tree_ensemble_handle,
          learner_config.SerializeToString(),
          apply_averaging=True,
          reduce_dim=True)

      pattern_result, pattern_dropout_info = self._get_predictions(
          adjusted_tree_ensemble_handle,
          learner_config_no_averaging.SerializeToString(),
          apply_averaging=False,
          reduce_dim=True)

      self.assertAllEqual(result.eval(), pattern_result.eval())
      self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval())

  def testCreate(self):
    with self.cached_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree = tree_ensemble_config.trees.add()
      _append_to_leaf(tree.nodes.add().leaf, 0, -0.4)
      tree_ensemble_config.tree_weights.append(1.0)

      # Prepare learner config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 2

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=3,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="create_tree")
      resources.initialize_resources(resources.shared_resources()).run()

      result, _ = prediction_ops.gradient_trees_prediction(
          tree_ensemble_handle,
          self._seed, [self._dense_float_tensor], [
              self._sparse_float_indices1, self._sparse_float_indices2
          ], [self._sparse_float_values1, self._sparse_float_values2],
          [self._sparse_float_shape1,
           self._sparse_float_shape2], [self._sparse_int_indices1],
          [self._sparse_int_values1], [self._sparse_int_shape1],
          learner_config=learner_config.SerializeToString(),
          apply_dropout=False,
          apply_averaging=False,
          center_bias=False,
          reduce_dim=True)
      self.assertAllClose(result.eval(), [[-0.4], [-0.4]])
      stamp_token = model_ops.tree_ensemble_stamp_token(tree_ensemble_handle)
      self.assertEqual(stamp_token.eval(), 3)

  def testSaveRestoreBeforeFlush(self):
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.cached_session() as sess:
      accumulator = boosted_trees_ops.QuantileAccumulator(
          num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")

      save = saver.Saver()
      resources.initialize_resources(resources.shared_resources()).run()

      summaries = accumulator.add_summaries([self._feature_0, self._feature_1],
                                            self._example_weights)
      self.evaluate(summaries)
      buckets = accumulator.get_bucket_boundaries()
      self.assertAllClose([], buckets[0].eval())
      self.assertAllClose([], buckets[1].eval())
      save.save(sess, save_path)
      self.evaluate(accumulator.flush())
      self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
      self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())

    with self.session(graph=ops.Graph()) as sess:
      accumulator = boosted_trees_ops.QuantileAccumulator(
          num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
      save = saver.Saver()
      save.restore(sess, save_path)
      buckets = accumulator.get_bucket_boundaries()
      self.assertAllClose([], buckets[0].eval())
      self.assertAllClose([], buckets[1].eval())

  def testCachedPredictionOnEmptyEnsemble(self):
    """Tests that prediction on a dummy ensemble does not fail."""
    with self.cached_session() as session:
      # Create a dummy ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto='')
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # No previous cached values.
      cached_tree_ids = [0, 0]
      cached_node_ids = [0, 0]

      # We have two features: 0 and 1. Values don't matter here on a dummy
      # ensemble.
      feature_0_values = [67, 5]
      feature_1_values = [9, 17]

      # Run training prediction on the dummy ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      # Nothing changed.
      self.assertAllClose(cached_tree_ids, new_tree_ids)
      self.assertAllClose(cached_node_ids, new_node_ids)
      self.assertAllClose([[0], [0]], logits_updates)

  def _testStreamingQuantileBucketsHelper(
      self, inputs, num_quantiles=3, expected_buckets=None):
    """Helper to test quantile buckets on different inputs."""

    # set generate_quantiles to True since the test will generate fewer
    # boundaries otherwise.
    with self.test_session() as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=num_quantiles,
          epsilon=0.001, name="q1", generate_quantiles=True)
      resources.initialize_resources(resources.shared_resources()).run()
    input_column = array_ops.placeholder(dtypes.float32)
    weights = array_ops.placeholder(dtypes.float32)
    update = accumulator.add_summary(
        stamp_token=0,
        column=input_column,
        example_weights=weights)

    with self.test_session() as sess:
      sess.run(update,
               {input_column: inputs,
                weights: [1] * len(inputs)})

    with self.test_session() as sess:
      sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
      are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
      buckets, are_ready_flush = (sess.run(
          [buckets, are_ready_flush]))
      self.assertEqual(True, are_ready_flush)
      # By default, use 3 quantiles, 4 boundaries for simplicity.
      self.assertEqual(num_quantiles + 1, len(buckets))
      if expected_buckets:
        self.assertAllEqual(buckets, expected_buckets)

  def testStreamingQuantileBuckets(self):
    """Sets up the quantile summary op test as follows.

    100 batches of data are added to the accumulator. The batches are of the
    form:
    [0 1 .. 99]
    [100 101 .. 199]
    ...
    [9900 9901 .. 9999]
    All the batches have 1 for all the example weights.
    """
    with self.test_session() as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=3, epsilon=0.01, name="q1")
      resources.initialize_resources(resources.shared_resources()).run()
    weight_placeholder = array_ops.placeholder(dtypes.float32)
    dense_placeholder = array_ops.placeholder(dtypes.float32)
    update = accumulator.add_summary(
        stamp_token=0,
        column=dense_placeholder,
        example_weights=weight_placeholder)
    with self.test_session() as sess:
      for i in range(100):
        dense_float = np.linspace(
            i * 100, (i + 1) * 100 - 1, num=100).reshape(-1, 1)
        sess.run(update, {
            dense_placeholder: dense_float,
            weight_placeholder: np.ones(shape=(100, 1), dtype=np.float32)
        })

    with self.test_session() as sess:
      sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
      are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
      buckets, are_ready_flush = (sess.run([buckets, are_ready_flush]))
      self.assertEqual(True, are_ready_flush)
      self.assertAllEqual([0, 3335., 6671., 9999.], buckets)
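
The expected boundaries above can be sanity-checked against the exact tertiles of the stream; a minimal sketch assuming only numpy (not part of the test):

import numpy as np

# The 100 batches above together feed the integers 0..9999, each with
# weight 1. Exact 0%, 33.3%, 66.7% and 100% percentiles:
data = np.arange(10000, dtype=np.float64)
print(np.percentile(data, [0, 100.0 / 3, 200.0 / 3, 100]))
# -> [0, 3333, 6666, 9999]; with epsilon=0.01 the summary may be off by up
# to epsilon * N = 100 ranks, so [0, 3335., 6671., 9999.] is acceptable.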

  def testContribsForOnlyABiasNode(self):
    """Tests case when, after training, only left with a bias node.

    For example, this could happen if the final ensemble contains one tree that
    got pruned up to the root.
    """
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(
          """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
      """, tree_ensemble_config)

      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # All features are unused.
      feature_0_values = [36, 32]
      feature_1_values = [13, -29]
      feature_2_values = [11, 27]

      # Expected logits are computed by traversing the logit path and
      # subtracting child logits from parent logits.
      bias = 1.72 * 0.1  # Root node of tree_0.
      expected_feature_ids = ((), ())
      expected_logits_paths = ((bias,), (bias,))

      bucketized_features = [
          feature_0_values, feature_1_values, feature_2_values
      ]

      debug_op = boosted_trees_ops.example_debug_outputs(
          tree_ensemble_handle,
          bucketized_features=bucketized_features,
          logits_dimension=1)

      serialized_examples_debug_outputs = session.run(debug_op)
      feature_ids = []
      logits_paths = []
      for example in serialized_examples_debug_outputs:
        example_debug_outputs = boosted_trees_pb2.DebugOutput()
        example_debug_outputs.ParseFromString(example)
        feature_ids.append(example_debug_outputs.feature_ids)
        logits_paths.append(example_debug_outputs.logits_path)

      self.assertAllClose(feature_ids, expected_feature_ids)
      self.assertAllClose(logits_paths, expected_logits_paths)
Example #12
  def testDropout(self):
    with self.test_session():
      # Empty tree ensemble.
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      # Add 999 trees with some weights.
      for i in range(0, 999):
        tree = tree_ensemble_config.trees.add()
        tree_ensemble_config.tree_metadata.add().is_finalized = True
        _append_to_leaf(tree.nodes.add().leaf, 0, -0.4)
        tree_ensemble_config.tree_weights.append(i + 1)

      # Prepare learner/dropout config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.learning_rate_tuner.dropout.dropout_probability = 0.5
      learner_config.learning_rate_tuner.dropout.learning_rate = 1.0
      learner_config.num_classes = 2

      # Apply dropout.
      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="existing")
      resources.initialize_resources(resources.shared_resources()).run()

      result, dropout_info = self._get_predictions(
          tree_ensemble_handle,
          learner_config=learner_config.SerializeToString(),
          apply_dropout=True,
          apply_averaging=False,
          center_bias=False,
          reduce_dim=True)

      # We expect approx 500 trees were dropped.
      dropout_info = dropout_info.eval()
      self.assertIn(dropout_info[0].size, range(400, 601))
      self.assertEqual(dropout_info[0].size, dropout_info[1].size)

      for i in range(dropout_info[0].size):
        dropped_index = dropout_info[0][i]
        dropped_weight = dropout_info[1][i]
        # We constructed the trees so that tree index + 1 equals the tree
        # weight, so we can check the weights of the dropped trees here.
        self.assertEqual(dropped_index + 1, dropped_weight)

      # Don't apply dropout.
      result_no_dropout, no_dropout_info = self._get_predictions(
          tree_ensemble_handle,
          learner_config=learner_config.SerializeToString(),
          apply_dropout=False,
          apply_averaging=False,
          center_bias=False,
          reduce_dim=True)

      self.assertEqual(result.eval().size, result_no_dropout.eval().size)
      for i in range(result.eval().size):
        self.assertNotEqual(result.eval()[i], result_no_dropout.eval()[i])

      # We expect none of the trees were dropped.
      self.assertAllEqual([[], []], no_dropout_info.eval())

  def testWithExistingEnsembleAndShrinkage(self):
    with self.test_session():
      # Add shrinkage config.
      learning_rate = 0.0001
      tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
      # Add 5 trees with some weights.
      for i in range(0, 5):
        tree = tree_ensemble.trees.add()
        _append_to_leaf(tree.nodes.add().leaf, 0, -0.4)
        tree_ensemble.tree_weights.append(i + 1)
        meta = tree_ensemble.tree_metadata.add()
        meta.num_tree_weight_updates = 1
      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble.SerializeToString(),
          name="existing")

      # Create non-zero feature importance.
      feature_usage_counts = variables.Variable(
          initial_value=np.array([4, 7], np.int64),
          name="feature_usage_counts",
          trainable=False)
      feature_gains = variables.Variable(
          initial_value=np.array([0.2, 0.8], np.float32),
          name="feature_gains",
          trainable=False)

      resources.initialize_resources(resources.shared_resources()).run()
      variables.initialize_all_variables().run()

      output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
      with ops.control_dependencies([
          ensemble_optimizer_ops.add_trees_to_ensemble(
              tree_ensemble_handle,
              self._ensemble_to_add.SerializeToString(),
              feature_usage_counts, [1, 2],
              feature_gains, [0.5, 0.3], [[], []],
              learning_rate=learning_rate)
      ]):
        output_ensemble.ParseFromString(
            model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval())

      # The weights of the previous trees stayed the same; the new (last) tree
      # is added with the shrinkage weight.
      self.assertAllClose([1.0, 2.0, 3.0, 4.0, 5.0, learning_rate],
                          output_ensemble.tree_weights)

      # Check that the number of updates is equal to 1 for every tree (i.e.,
      # no old tree weight got adjusted).
      for i in range(0, 6):
        self.assertEqual(
            1, output_ensemble.tree_metadata[i].num_tree_weight_updates)

      # Ensure feature importance was aggregated correctly.
      self.assertAllEqual([5, 9], feature_usage_counts.eval())
      self.assertArrayNear(
          [0.2 + 0.5 * learning_rate, 0.8 + 0.3 * learning_rate],
          feature_gains.eval(), 1e-6)

  def testCreate(self):
    with self.test_session():
      ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
      resources.initialize_resources(resources.shared_resources()).run()
      stamp_token = ensemble.get_stamp_token()
      self.assertEqual(0, stamp_token.eval())
      (_, num_trees, num_finalized_trees,
       num_attempted_layers) = ensemble.get_states()
      self.assertEqual(0, num_trees.eval())
      self.assertEqual(0, num_finalized_trees.eval())
      self.assertEqual(0, num_attempted_layers.eval())
Example #15
  def testCreate(self):
    with self.cached_session():
      ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
      resources.initialize_resources(resources.shared_resources()).run()
      stamp_token = ensemble.get_stamp_token()
      self.assertEqual(0, self.evaluate(stamp_token))
      (_, num_trees, num_finalized_trees, num_attempted_layers,
       nodes_range) = ensemble.get_states()
      self.assertEqual(0, self.evaluate(num_trees))
      self.assertEqual(0, self.evaluate(num_finalized_trees))
      self.assertEqual(0, self.evaluate(num_attempted_layers))
      self.assertAllEqual([0, 1], self.evaluate(nodes_range))
Example #16
  def testPredictFn(self):
    """Tests the predict function."""
    with self.test_session() as sess:
      # Create ensemble with one bias node.
      ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      text_format.Merge("""
          trees {
            nodes {
              leaf {
                vector {
                  value: 0.25
                }
              }
            }
          }
          tree_weights: 1.0
          tree_metadata {
            num_tree_weight_updates: 1
            num_layers_grown: 1
            is_finalized: true
          }""", ensemble_config)
      ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=3,
          tree_ensemble_config=ensemble_config.SerializeToString(),
          name="tree_ensemble")
      resources.initialize_resources(resources.shared_resources()).run()
      learner_config = learner_pb2.LearnerConfig()
      learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
      learner_config.num_classes = 2
      learner_config.regularization.l1 = 0
      learner_config.regularization.l2 = 0
      learner_config.constraints.max_tree_depth = 1
      learner_config.constraints.min_node_weight = 0
      features = {}
      features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)
      gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
          is_chief=False,
          num_ps_replicas=0,
          center_bias=True,
          ensemble_handle=ensemble_handle,
          examples_per_layer=1,
          learner_config=learner_config,
          features=features)

      # Create predict op.
      mode = model_fn.ModeKeys.EVAL
      predictions_dict = sess.run(gbdt_model.predict(mode))
      self.assertEqual(predictions_dict["ensemble_stamp"], 3)
      self.assertAllClose(predictions_dict["predictions"], [[0.25], [0.25],
                                                            [0.25], [0.25]])
      self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])

  def testSaveRestoreBeforeFlush(self):
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.test_session(graph=ops.Graph()) as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")

      save = saver.Saver()
      resources.initialize_resources(resources.shared_resources()).run()

      sparse_indices_0 = constant_op.constant(
          [[1, 0], [2, 1], [3, 0], [4, 2], [5, 0]], dtype=dtypes.int64)
      sparse_values_0 = constant_op.constant(
          [2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtypes.float32)
      sparse_shape_0 = constant_op.constant([6, 3], dtype=dtypes.int64)
      example_weights = constant_op.constant(
          [10, 1, 1, 1, 1, 1], dtype=dtypes.float32, shape=[6, 1])
      update = accumulator.add_summary(
          stamp_token=0,
          column=sparse_tensor.SparseTensor(sparse_indices_0, sparse_values_0,
                                            sparse_shape_0),
          example_weights=example_weights)
      update.run()
      save.save(sess, save_path)
      reset = accumulator.flush(stamp_token=0, next_stamp_token=1)
      with ops.control_dependencies([reset]):
        are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
      buckets, are_ready_flush = (sess.run([buckets, are_ready_flush]))
      self.assertEqual(True, are_ready_flush)
      self.assertAllEqual([2, 4, 6.], buckets)

    with self.test_session(graph=ops.Graph()) as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")
      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      are_ready_noflush = accumulator.get_buckets(stamp_token=0)[0]
      with ops.control_dependencies([are_ready_noflush]):
        reset = accumulator.flush(stamp_token=0, next_stamp_token=1)

      with ops.control_dependencies([reset]):
        are_ready_flush, buckets = accumulator.get_buckets(stamp_token=1)
      buckets, are_ready_flush, are_ready_noflush = (sess.run(
          [buckets, are_ready_flush, are_ready_noflush]))
      self.assertFalse(are_ready_noflush)
      self.assertTrue(are_ready_flush)
      self.assertAllEqual([2, 4, 6.], buckets)
Example #18
    def _init_graph(self):
        # Initialize all weights.
        if not self._is_initialized:
            self.saver = tf.train.Saver()
            init_vars = tf.group(tf.global_variables_initializer(),
                                 resources.initialize_resources(
                                     resources.shared_resources()))
            self.session.run(init_vars)
            self._is_initialized = True
        # Restore weights if needed.
        if self._to_be_restored:
            self.saver = tf.train.Saver()
            self.saver.restore(self.session, self._to_be_restored)
            self._to_be_restored = False
Example #19
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Yields:
    A sequence of dicts of values read from `output_dict` tensors, one item
    yielded for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % output_dict)
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as g:
    with tf_session.Session('') as session:
      session.run(
          resources.initialize_resources(resources.shared_resources() +
                                         resources.local_resources()))
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.global_variables_initializer())
      session.run(variables.local_variables_initializer())
      session.run(data_flow_ops.initialize_all_tables())
      coord = coordinator.Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        for f in feed_dicts:
          yield session.run(output_dict, f)
      finally:
        coord.request_stop()
        if threads:
          coord.join(threads, stop_grace_period_secs=120)
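
For context, here is a minimal usage sketch of run_feeds_iter; the graph, tensor names, and feed values are hypothetical, not taken from the library:

import tensorflow as tf

# Hypothetical one-op graph: double the fed value.
x = tf.placeholder(tf.float32, name='x')
y = tf.multiply(x, 2.0, name='y')

# One dict is yielded per feed dict, keyed like output_dict.
for result in run_feeds_iter(output_dict={'doubled': y},
                             feed_dicts=[{x: 1.0}, {x: 2.0}, {x: 3.0}]):
  print(result['doubled'])  # 2.0, then 4.0, then 6.0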
Example #20
  def testMetadataMissing(self):
    # Sometimes we want to do prediction on trees that are not yet added to the
    # ensemble.
    with self.test_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      # Bias tree.
      tree1 = tree_ensemble_config.trees.add()
      _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4)

      # Depth 3 tree.
      tree2 = tree_ensemble_config.trees.add()
      # We are not setting the tree_ensemble_config.tree_metadata in this test.
      _set_float_split(tree2.nodes.add().dense_float_binary_split, 0, 9.0, 1, 2)
      _set_float_split(tree2.nodes.add()
                       .sparse_float_binary_split_default_left.split, 0, -20.0,
                       3, 4)
      _append_to_leaf(tree2.nodes.add().leaf, 0, 0.5)
      _append_to_leaf(tree2.nodes.add().leaf, 0, 1.2)
      _set_categorical_id_split(tree2.nodes.add().categorical_id_binary_split,
                                0, 9, 5, 6)
      _append_to_leaf(tree2.nodes.add().leaf, 0, -0.9)
      _append_to_leaf(tree2.nodes.add().leaf, 0, 0.7)

      tree_ensemble_config.tree_weights.append(1.0)
      tree_ensemble_config.tree_weights.append(1.0)

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="full_ensemble")
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare learner config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 2

      result, dropout_info = self._get_predictions(
          tree_ensemble_handle,
          learner_config=learner_config.SerializeToString(),
          reduce_dim=True)

      # The first example will get the bias -0.4 from the first tree and the
      # leaf 4 payload of -0.9, hence -1.3. The second example will get the
      # same bias -0.4 and the leaf 3 payload (sparse feature missing) of 1.2,
      # hence 0.8.
      self.assertAllClose([[-1.3], [0.8]], result.eval())

      # Empty dropout.
      self.assertAllEqual([[], []], dropout_info.eval())
Example #21
  def testUsedHandlers(self):
    with self.cached_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree_ensemble_config.growing_metadata.used_handler_ids.append(1)
      tree_ensemble_config.growing_metadata.used_handler_ids.append(5)
      stamp_token = 3
      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=stamp_token,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="create_tree")
      resources.initialize_resources(resources.shared_resources()).run()
      result = model_ops.tree_ensemble_used_handlers(
          tree_ensemble_handle, stamp_token, num_all_handlers=6)
      self.assertAllEqual([0, 1, 0, 0, 0, 1], result.used_handlers_mask.eval())
      self.assertEqual(2, result.num_used_handlers.eval())

  def testInactive(self):
    with self.test_session() as sess:
      gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
      hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
      partition_ids = [0, 0, 0, 1]
      indices = [[0, 0], [0, 1], [2, 0], [3, 0]]
      values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64)

      gradient_shape = tensor_shape.scalar()
      hessian_shape = tensor_shape.scalar()
      class_id = -1

      split_handler = categorical_split_handler.EqualitySplitHandler(
          l1_regularization=0.1,
          l2_regularization=1,
          tree_complexity_regularization=0,
          min_node_weight=0,
          sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
          feature_column_group_id=0,
          gradient_shape=gradient_shape,
          hessian_shape=hessian_shape,
          multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
          init_stamp_token=0)
      resources.initialize_resources(resources.shared_resources()).run()

      empty_gradients, empty_hessians = get_empty_tensors(
          gradient_shape, hessian_shape)
      example_weights = array_ops.ones([4, 1], dtypes.float32)

      update_1 = split_handler.update_stats_sync(
          0,
          partition_ids,
          gradients,
          hessians,
          empty_gradients,
          empty_hessians,
          example_weights,
          is_active=array_ops.constant([False, False]))
      with ops.control_dependencies([update_1]):
        are_splits_ready, partitions, gains, splits = (
            split_handler.make_splits(0, 1, class_id))
        are_splits_ready, partitions, gains, splits = (sess.run(
            [are_splits_ready, partitions, gains, splits]))
    self.assertTrue(are_splits_ready)
    self.assertEqual(len(partitions), 0)
    self.assertEqual(len(gains), 0)
    self.assertEqual(len(splits), 0)
Example #23
  def testExcludeNonFinalTree(self):
    with self.test_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      # Bias tree.
      tree1 = tree_ensemble_config.trees.add()
      tree_ensemble_config.tree_metadata.add().is_finalized = True
      _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4)

      # Depth 3 tree.
      tree2 = tree_ensemble_config.trees.add()
      tree_ensemble_config.tree_metadata.add().is_finalized = False
      _set_float_split(tree2.nodes.add().dense_float_binary_split, 0, 9.0, 1, 2)
      _set_float_split(tree2.nodes.add()
                       .sparse_float_binary_split_default_left.split, 0, -20.0,
                       3, 4)
      _append_to_leaf(tree2.nodes.add().leaf, 0, 0.5)
      _append_to_leaf(tree2.nodes.add().leaf, 0, 1.2)
      _set_categorical_id_split(tree2.nodes.add().categorical_id_binary_split,
                                0, 9, 5, 6)
      _append_to_leaf(tree2.nodes.add().leaf, 0, -0.9)
      _append_to_leaf(tree2.nodes.add().leaf, 0, 0.7)

      tree_ensemble_config.tree_weights.append(1.0)
      tree_ensemble_config.tree_weights.append(1.0)

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="full_ensemble")
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare learner config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 2
      learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE

      result, dropout_info = self._get_predictions(
          tree_ensemble_handle,
          learner_config=learner_config.SerializeToString(),
          reduce_dim=True)

      # All the examples should get only the bias, since the second tree is
      # non-finalized.
      self.assertAllClose([[-0.4], [-0.4]], result.eval())

      # Empty dropout.
      self.assertAllEqual([[], []], dropout_info.eval())

  def testStreamingQuantileBucketsLowPrecisionInput(self):
    """Tests inputs that simulate low precision float16 values."""

    num_quantiles = 3
    # set generate_quantiles to True since the test will generate fewer
    # boundaries otherwise.
    with self.test_session() as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=num_quantiles,
          epsilon=0.001, name="q1", generate_quantiles=True)
      resources.initialize_resources(resources.shared_resources()).run()
    input_column = array_ops.placeholder(dtypes.float32)
    weights = array_ops.placeholder(dtypes.float32)
    update = accumulator.add_summary(
        stamp_token=0,
        column=input_column,
        example_weights=weights)

    with self.test_session() as sess:
      # This input is generated from integers in the range [2030, 2060]
      # but represented with float16 precision. Integers <= 2048 are exactly
      # representable, whereas numbers > 2048 are rounded, so values > 2048
      # are repeated. For precision loss / rounding, see:
      # https://en.wikipedia.org/wiki/Half-precision_floating-point_format.
      #
      # The intent of the test is not the handling of float16 values, but to
      # validate that the expected number of buckets is returned in cases
      # where the input may contain repeated values.
      inputs = [
          2030.0, 2031.0, 2032.0, 2033.0, 2034.0, 2035.0, 2036.0, 2037.0,
          2038.0, 2039.0, 2040.0, 2041.0, 2042.0, 2043.0, 2044.0, 2045.0,
          2046.0, 2047.0, 2048.0, 2048.0, 2050.0, 2052.0, 2052.0, 2052.0,
          2054.0, 2056.0, 2056.0, 2056.0, 2058.0, 2060.0
      ]
      sess.run(update,
               {input_column: inputs,
                weights: [1] * len(inputs)})

    with self.test_session() as sess:
      sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
      are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
      buckets, are_ready_flush = (sess.run(
          [buckets, are_ready_flush]))
      self.assertEqual(True, are_ready_flush)
      self.assertEqual(num_quantiles + 1, len(buckets))
      self.assertAllEqual([2030, 2040, 2050, 2060], buckets)
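
The float16 rounding the comment above describes is easy to reproduce with numpy alone; a small standalone sketch (not part of the test):

import numpy as np

ints = np.arange(2030, 2061, dtype=np.float64)
round_trip = ints.astype(np.float16).astype(np.float64)
# Integers <= 2048 survive the round trip exactly; above 2048 the spacing
# between adjacent float16 values is 2, so odd integers round to an even
# neighbor and duplicates appear (e.g. 2051 -> 2052, 2053 -> 2052).
assert (round_trip[ints <= 2048] == ints[ints <= 2048]).all()
print(sorted(set(round_trip[ints > 2048])))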
Example #25
  def testFullEnsembleMultiNotClassTreePerClassStrategyDenseVector(self):
    with self.test_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      # Bias tree only for second class.
      tree1 = tree_ensemble_config.trees.add()
      tree_ensemble_config.tree_metadata.add().is_finalized = True
      _append_multi_values_to_dense_leaf(tree1.nodes.add().leaf, [0, -0.2, -2])

      # Depth 2 tree.
      tree2 = tree_ensemble_config.trees.add()
      tree_ensemble_config.tree_metadata.add().is_finalized = True
      _set_float_split(tree2.nodes.add()
                       .sparse_float_binary_split_default_right.split, 1, 4.0,
                       1, 2)
      _set_float_split(tree2.nodes.add().dense_float_binary_split, 0, 9.0, 3, 4)
      _append_multi_values_to_dense_leaf(tree2.nodes.add().leaf, [0.5, 0, 0])
      _append_multi_values_to_dense_leaf(tree2.nodes.add().leaf, [0, 1.2, -0.7])
      _append_multi_values_to_dense_leaf(tree2.nodes.add().leaf, [-0.9, 0, 0])

      tree_ensemble_config.tree_weights.append(1.0)
      tree_ensemble_config.tree_weights.append(1.0)

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="ensemble_multi_class")
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare learner config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 3
      learner_config.multi_class_strategy = (
          learner_pb2.LearnerConfig.FULL_HESSIAN)

      result, dropout_info = self._get_predictions(
          tree_ensemble_handle,
          learner_config=learner_config.SerializeToString(),
          reduce_dim=False)
      # The first example gets the bias [0, -0.2, -2] from the first tree and
      # the leaf 2 payload (sparse feature missing) of [0.5, 0, 0], hence
      # [0.5, -0.2, -2.0]. The second example gets the same bias and the leaf 3
      # payload of [0, 1.2, -0.7], hence [0, 1.0, -2.7].
      self.assertAllClose([[0.5, -0.2, -2.0], [0, 1.0, -2.7]], result.eval())

      # Empty dropout.
      self.assertAllEqual([[], []], dropout_info.eval())

  def testWithExistingEnsemble(self):
    with self.test_session():
      # Create existing tree ensemble.
      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=self._tree_ensemble.SerializeToString(),
          name="existing")
      # Create non-zero feature importance.
      feature_usage_counts = variables.Variable(
          initial_value=np.array([0, 4, 1], np.int64),
          name="feature_usage_counts",
          trainable=False)
      feature_gains = variables.Variable(
          initial_value=np.array([0.0, 0.3, 0.05], np.float32),
          name="feature_gains",
          trainable=False)

      resources.initialize_resources(resources.shared_resources()).run()
      variables.initialize_all_variables().run()
      output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
      with ops.control_dependencies([
          ensemble_optimizer_ops.add_trees_to_ensemble(
              tree_ensemble_handle,
              self._ensemble_to_add.SerializeToString(),
              feature_usage_counts, [1, 2, 0],
              feature_gains, [0.02, 0.1, 0.0], [[], []],
              learning_rate=1)
      ]):
        output_ensemble.ParseFromString(
            model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval())

      # Output.
      self.assertEqual(3, len(output_ensemble.trees))
      self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[2])

      self.assertAllEqual([1.0, 1.0, 1.0], output_ensemble.tree_weights)

      self.assertEqual(2,
                       output_ensemble.tree_metadata[0].num_tree_weight_updates)
      self.assertEqual(3,
                       output_ensemble.tree_metadata[1].num_tree_weight_updates)
      self.assertEqual(1,
                       output_ensemble.tree_metadata[2].num_tree_weight_updates)
      self.assertAllEqual([1, 6, 1], feature_usage_counts.eval())
      self.assertArrayNear([0.02, 0.4, 0.05], feature_gains.eval(), 1e-6)

  def testStreamingQuantileBucketsWithVaryingBatch(self):
    """Sets up the quantile summary op test as follows.

    Creates batches of examples with a different number of inputs per batch.
    The input values are dense in the range [1 ... N].
    The data looks like this:
    | Batch | Start | InputList
    |   1   |   1   |  [1]
    |   2   |   2   |  [2, 3]
    |   3   |   4   |  [4, 5, 6]
    |   4   |   7   |  [7, 8, 9, 10]
    |   5   |  11   |  [11, 12, 13, 14, 15]
    |   6   |  16   |  [16, 17, 18, 19, 20, 21]
    """

    num_quantiles = 3
    with self.test_session() as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=num_quantiles,
          epsilon=0.001, name="q1")
      resources.initialize_resources(resources.shared_resources()).run()
    input_column = array_ops.placeholder(dtypes.float32)
    weights = array_ops.placeholder(dtypes.float32)
    update = accumulator.add_summary(
        stamp_token=0,
        column=input_column,
        example_weights=weights)

    with self.test_session() as sess:
      for i in range(1, 23):
        # start = 1, 2, 4, 7, 11, 16 ... (see comment above)
        start = int((i * (i-1) / 2) + 1)
        sess.run(update,
                 {input_column: range(start, start+i),
                  weights: [1] * i})

    with self.test_session() as sess:
      sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
      are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
      buckets, are_ready_flush = (sess.run(
          [buckets, are_ready_flush]))
      self.assertEqual(True, are_ready_flush)
      self.assertEqual(num_quantiles + 1, len(buckets))
      self.assertAllEqual([1, 86., 170., 253.], buckets)
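
The start formula in the loop above can be checked independently; a short sketch verifying that the triangular batches tile the integers 1..253 exactly once (plain Python, not part of the test):

# Batches i = 1..22 start at i*(i-1)/2 + 1 and contain i values each.
seen = []
for i in range(1, 23):
    start = i * (i - 1) // 2 + 1
    seen.extend(range(start, start + i))
assert seen == list(range(1, 254))  # the integers 1..253, no gaps, no repeats
# So the accumulator receives 1..253, and [1, 86., 170., 253.] are its
# approximate tertile boundaries.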
Example #28
  def testEnsembleEmpty(self):
    with self.test_session():
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="full_ensemble")
      resources.initialize_resources(resources.shared_resources()).run()

      result = prediction_ops.gradient_trees_partition_examples(
          tree_ensemble_handle, [self._dense_float_tensor], [
              self._sparse_float_indices1, self._sparse_float_indices2
          ], [self._sparse_float_values1, self._sparse_float_values2],
          [self._sparse_float_shape1,
           self._sparse_float_shape2], [self._sparse_int_indices1],
          [self._sparse_int_values1], [self._sparse_int_shape1])

      self.assertAllEqual([0, 0], result.eval())

  def testWithEmptyEnsemble(self):
    with self.test_session():
      # Create an empty ensemble.
      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0, tree_ensemble_config="", name="empty")

      # Create zero feature importance.
      feature_usage_counts = variables.Variable(
          initial_value=array_ops.zeros([1], dtypes.int64),
          name="feature_usage_counts",
          trainable=False)
      feature_gains = variables.Variable(
          initial_value=array_ops.zeros([1], dtypes.float32),
          name="feature_gains",
          trainable=False)

      resources.initialize_resources(resources.shared_resources()).run()
      variables.initialize_all_variables().run()

      with ops.control_dependencies([
          ensemble_optimizer_ops.add_trees_to_ensemble(
              tree_ensemble_handle,
              self._ensemble_to_add.SerializeToString(),
              feature_usage_counts, [2],
              feature_gains, [0.4], [[]],
              learning_rate=1.0)
      ]):
        result = model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1]

      # Output.
      output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
      output_ensemble.ParseFromString(result.eval())
      self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[0])
      self.assertEqual(1, len(output_ensemble.trees))

      self.assertAllEqual([1.0], output_ensemble.tree_weights)

      self.assertEqual(1,
                       output_ensemble.tree_metadata[0].num_tree_weight_updates)

      self.assertAllEqual([2], feature_usage_counts.eval())
      self.assertArrayNear([0.4], feature_gains.eval(), 1e-6)

  def testWithEmptyEnsembleAndShrinkage(self):
    with self.test_session():
      # Add shrinkage config.
      learning_rate = 0.0001
      tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0,
          tree_ensemble_config=tree_ensemble.SerializeToString(),
          name="existing")

      # Create zero feature importance.
      feature_usage_counts = variables.Variable(
          initial_value=np.array([0, 0], np.int64),
          name="feature_usage_counts",
          trainable=False)
      feature_gains = variables.Variable(
          initial_value=np.array([0.0, 0.0], np.float32),
          name="feature_gains",
          trainable=False)

      resources.initialize_resources(resources.shared_resources()).run()
      variables.initialize_all_variables().run()

      output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
      with ops.control_dependencies([
          ensemble_optimizer_ops.add_trees_to_ensemble(
              tree_ensemble_handle,
              self._ensemble_to_add.SerializeToString(),
              feature_usage_counts, [1, 2],
              feature_gains, [0.5, 0.3], [[], []],
              learning_rate=learning_rate)
      ]):
        output_ensemble.ParseFromString(
            model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval())

      # New tree is added with shrinkage weight.
      self.assertAllClose([learning_rate], output_ensemble.tree_weights)
      self.assertEqual(1,
                       output_ensemble.tree_metadata[0].num_tree_weight_updates)
      self.assertAllEqual([1, 2], feature_usage_counts.eval())
      self.assertArrayNear([0.5 * learning_rate, 0.3 * learning_rate],
                           feature_gains.eval(), 1e-6)
Example #31
    def testStreamingQuantileBuckets(self):
        """Sets up the quantile summary op test as follows.

        Create a batch of 6 examples having dense and sparse features.
        The data looks like this:
        | Instance | instance weights | Dense 0
        | 0        |     10           |   1
        | 1        |     1            |   2
        | 2        |     1            |   3
        | 3        |     1            |   4
        | 4        |     1            |   4
        | 5        |     1            |   5
        """
        dense_float_tensor_0 = np.array([1, 2, 3, 4, 4, 5])
        example_weights = np.array([10, 1, 1, 1, 1, 1])

        with self.test_session() as sess:
            accumulator = quantile_ops.QuantileAccumulator(init_stamp_token=0,
                                                           num_quantiles=3,
                                                           epsilon=0.33,
                                                           name="q1")

            resources.initialize_resources(resources.shared_resources()).run()

            are_ready_noflush, _ = accumulator.get_buckets(stamp_token=0)

            update = accumulator.add_summary(stamp_token=0,
                                             column=dense_float_tensor_0,
                                             example_weights=example_weights)
            with ops.control_dependencies([are_ready_noflush, update]):
                reset = accumulator.flush(stamp_token=0, next_stamp_token=1)
            with ops.control_dependencies([reset]):
                are_ready_flush, buckets = (accumulator.get_buckets(
                    stamp_token=1))
            buckets, are_ready_noflush, are_ready_flush = (sess.run(
                [buckets, are_ready_noflush, are_ready_flush]))
            self.assertEqual(False, are_ready_noflush)
            self.assertEqual(True, are_ready_flush)
            self.assertAllEqual([1, 3, 5], buckets)
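
The instance weights in the docstring explain why [1, 3, 5] are acceptable boundaries: value 1 alone carries 10 of the 15 units of weight. A small sketch of the weighted CDF, assuming only numpy (not part of the test):

import numpy as np

values = np.array([1, 2, 3, 4, 4, 5], dtype=np.float64)
weights = np.array([10, 1, 1, 1, 1, 1], dtype=np.float64)
total = weights.sum()  # 15
for v in np.unique(values):
    # Fraction of total weight at or below each distinct value.
    print(v, weights[values <= v].sum() / total)
# 1 -> 0.667, 2 -> 0.733, 3 -> 0.8, 4 -> 0.933, 5 -> 1.0; with epsilon=0.33
# the boundaries [1, 3, 5] are valid approximate weighted tertiles.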
Example #32
  def testSaveRestoreAfterFlush(self):
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.test_session(graph=ops.Graph()) as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")

      save = saver.Saver()
      resources.initialize_resources(resources.shared_resources()).run()

      example_weights = constant_op.constant(
          [10, 1, 1, 1, 1, 1], dtype=dtypes.float32, shape=[6, 1])
      dense_float_tensor_0 = constant_op.constant(
          [1, 2, 3, 4, 4, 5], dtype=dtypes.float32, shape=[6, 1])
      update = accumulator.add_summary(
          stamp_token=0,
          column=dense_float_tensor_0,
          example_weights=example_weights)
      update.run()
      reset = accumulator.flush(stamp_token=0, next_stamp_token=1)
      with ops.control_dependencies([reset]):
        are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
      buckets, are_ready_flush = (sess.run([buckets, are_ready_flush]))
      self.assertEqual(True, are_ready_flush)
      self.assertAllEqual([1, 3, 5], buckets)
      save.save(sess, save_path)

    with self.test_session(graph=ops.Graph()) as sess:
      accumulator = quantile_ops.QuantileAccumulator(
          init_stamp_token=0, num_quantiles=3, epsilon=0.33, name="q0")
      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
      buckets, are_ready_flush = (sess.run([buckets, are_ready_flush]))
      self.assertEqual(True, are_ready_flush)
      self.assertAllEqual([1, 3, 5], buckets)
Example #33
  def testBasicQuantileBucketsSingleResource(self):
    with self.cached_session() as sess:
      quantile_accumulator_handle = self.create_resource("floats", self.eps,
                                                         self.max_elements, 2)
      resources.initialize_resources(resources.shared_resources()).run()
      summaries = boosted_trees_ops.make_quantile_summaries(
          [self._feature_0, self._feature_1], self._example_weights,
          epsilon=self.eps)
      summary_op = boosted_trees_ops.quantile_add_summaries(
          quantile_accumulator_handle, summaries)
      flush_op = boosted_trees_ops.quantile_flush(
          quantile_accumulator_handle, self.num_quantiles)
      buckets = boosted_trees_ops.get_bucket_boundaries(
          quantile_accumulator_handle, num_features=2)
      quantiles = boosted_trees_ops.boosted_trees_bucketize(
          [self._feature_0, self._feature_1], buckets)
      sess.run(summary_op)
      sess.run(flush_op)
      self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
      self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())

      self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
      self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
Example #34
    def set_parameter(self, param):
        for name in self.default_param:
            if name not in param:
                param[name] = self.default_param[name]

        self.build_model()
        num_trees = param['num_trees']
        max_nodes = param['max_nodes']

        # Random Forest Parameters
        self.hparams = tensor_forest.ForestHParams(
            num_classes=self.class_num,
            num_features=self.feature_num,
            num_trees=num_trees,
            max_nodes=max_nodes).fill()

        # Build the Random Forest
        self.forest_graph = tensor_forest.RandomForestGraphs(self.hparams)
        # Get training graph and loss
        self.train_op = self.forest_graph.training_graph(
            self.inputs, self.labels)
        self.loss = self.forest_graph.training_loss(self.inputs, self.labels)

        # Measure the accuracy
        self.infer_op, _, _ = self.forest_graph.inference_graph(self.inputs)
        self.correct_prediction = tf.equal(tf.argmax(self.infer_op, 1),
                                           tf.cast(self.labels, tf.int64))
        self.accuracy = tf.reduce_mean(
            tf.cast(self.correct_prediction, tf.float32))

        #metrics = [self.get_metric(metric) for metric in param["metrics"]]
        #self.metrics = [metric_fun(self.output, self.ground_truth) for metric_fun in metrics]
        self.init_vars = tf.group(
            tf.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()))
        self.batch_size = param["batch_size"]
        self.num_epochs = param["num_epochs"]
Example #35
    def testSaveRestoreAfterFlush(self):
        save_dir = os.path.join(self.get_temp_dir(), "save_restore")
        save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

        with self.cached_session() as sess:
            accumulator = boosted_trees_ops.QuantileAccumulator(
                num_streams=2,
                num_quantiles=self.num_quantiles,
                epsilon=self.eps,
                name="q0")

            save = saver.Saver()
            resources.initialize_resources(resources.shared_resources()).run()

            buckets = accumulator.get_bucket_boundaries()
            self.assertAllClose([], buckets[0].eval())
            self.assertAllClose([], buckets[1].eval())
            summaries = accumulator.add_summaries(
                [self._feature_0, self._feature_1], self._example_weights)
            with ops.control_dependencies([summaries]):
                flush = accumulator.flush()
            self.evaluate(flush)
            self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
            self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
            save.save(sess, save_path)

        with self.session(graph=ops.Graph()) as sess:
            accumulator = boosted_trees_ops.QuantileAccumulator(
                num_streams=2,
                num_quantiles=self.num_quantiles,
                epsilon=self.eps,
                name="q0")
            save = saver.Saver()
            save.restore(sess, save_path)
            buckets = accumulator.get_bucket_boundaries()
            self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
            self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
Example #36
    def _testStreamingQuantileBucketsHelper(self,
                                            inputs,
                                            num_quantiles=3,
                                            expected_buckets=None):
        """Helper to test quantile buckets on different inputs."""

        # set generate_quantiles to True since the test will generate fewer
        # boundaries otherwise.
        with self.test_session() as sess:
            accumulator = quantile_ops.QuantileAccumulator(
                init_stamp_token=0,
                num_quantiles=num_quantiles,
                epsilon=0.001,
                name="q1",
                generate_quantiles=True)
            resources.initialize_resources(resources.shared_resources()).run()
        input_column = array_ops.placeholder(dtypes.float32)
        weights = array_ops.placeholder(dtypes.float32)
        update = accumulator.add_summary(stamp_token=0,
                                         column=input_column,
                                         example_weights=weights)

        with self.test_session() as sess:
            sess.run(update, {
                input_column: inputs,
                weights: [1] * len(inputs)
            })

        with self.test_session() as sess:
            sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1))
            are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1))
            buckets, are_ready_flush = (sess.run([buckets, are_ready_flush]))
            self.assertEqual(True, are_ready_flush)
            # By default, use 3 quantiles, 4 boundaries for simplicity.
            self.assertEqual(num_quantiles + 1, len(buckets))
            if expected_buckets:
                self.assertAllEqual(buckets, expected_buckets)
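The helper might then be exercised as follows (hypothetical invocations; the inputs and expected buckets are illustrative only):

# All-identical input: every boundary collapses onto the same value.
self._testStreamingQuantileBucketsHelper([5] * 10, expected_buckets=[5, 5, 5, 5])
# Spread-out input, default 3 quantiles, so 4 boundaries come back.
self._testStreamingQuantileBucketsHelper([1, 2, 3, 4, 5, 6])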
Example #37
# Build the Random Forest
forest_graph = tensor_forest.RandomForestGraphs(hparams)
# Get training graph and loss
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)

# Measure the accuracy
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Initialize the variables (i.e. assign their default value) and forest resources
init_vars = tf.group(
    tf.global_variables_initializer(),
    resources.initialize_resources(resources.shared_resources()))

# Start TensorFlow session
sess = tf.train.MonitoredSession()

# Run the initializer
sess.run(init_vars)

# Training
for i in range(1, num_steps + 1):
    # Prepare data. The commented line shows the MNIST mini-batch variant
    # this snippet was adapted from; here the full arrays are fed instead.
    #  batch_x, batch_y = mnist.train.next_batch(batch_size)
    _, l = sess.run([train_op, loss_op], feed_dict={X: input_x, Y: input_y})
    if i % 50 == 0 or i == 1:
        acc = sess.run(accuracy_op, feed_dict={X: X_train, Y: y_train})
        print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))
Example #38
    def testPredictionMultipleTree(self):
        """Tests the predictions work when we have multiple trees."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
      """, tree_ensemble_config)

            # Create existing ensemble with one root split
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            feature_0_values = [36, 32]
            feature_1_values = [11, 27]

            # Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 =>
            #            logit = 0.1*1.14 + 0.2*5.0 + 1.0*5.0 = 6.114
            # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7.0 =>
            #            logit = 0.1*1.14 + 0.2*7.0 - 1.0*7.0 = -5.486
            expected_logits = [[6.114], [-5.486]]

            # Do with parallelization, e.g. EVAL
            predict_op = boosted_trees_ops.predict(
                tree_ensemble_handle,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits = session.run(predict_op)
            self.assertAllClose(expected_logits, logits)

            # Do without parallelization, e.g. INFER - the result is the same
            predict_op = boosted_trees_ops.predict(
                tree_ensemble_handle,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits = session.run(predict_op)
            self.assertAllClose(expected_logits, logits)
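As a sanity check on the comments above, the expected logits are just the weight-scaled sums of the per-tree leaf values (plain Python, independent of TensorFlow):

weights = [0.1, 0.2, 1.0]  # tree_weights from the proto above
example_0 = weights[0] * 1.14 + weights[1] * 5.0 + weights[2] * 5.0     # 6.114
example_1 = weights[0] * 1.14 + weights[1] * 7.0 + weights[2] * (-7.0)  # -5.486
assert abs(example_0 - 6.114) < 1e-6 and abs(example_1 + 5.486) < 1e-6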
Example #39
    def testCachedPredictionTheWholeTreeWasPruned(self):
        """Tests that prediction based on previous node in the tree works."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            leaf {
              scalar: 0.00
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: -6.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 5.0
          }
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

            # Create existing ensemble.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            cached_tree_ids = [
                0,
                0,
            ]
            # The predictions were cached in 1 and 2, both were pruned to the root.
            cached_node_ids = [1, 2]

            # We have two features: 0 and 1. These are not going to be used anywhere.
            feature_0_values = [12, 17]
            feature_1_values = [12, 12]

            # Grow tree ensemble.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # We are in the last tree.
            self.assertAllClose([0, 0], new_tree_ids)
            self.assertAllClose([0, 0], new_node_ids)

            self.assertAllClose([[-6.0], [5.0]], logits_updates)
Example #40
    def testNoCachedPredictionButTreeExists(self):
        """Tests that predictions are updated once trees are added."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

            # Create existing ensemble with one root split
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples, none were cached before.
            cached_tree_ids = [0, 0]
            cached_node_ids = [0, 0]

            feature_0_values = [67, 5]

            # Grow tree ensemble.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # We are in the first tree.
            self.assertAllClose([0, 0], new_tree_ids)
            self.assertAllClose([2, 1], new_node_ids)
            self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logits_updates)
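The routing behind these assertions follows the bucketized_split rule: an example goes to the left child when its bucketized value is less than or equal to the threshold. A sketch of that rule (not the TF kernel):

def route_bucketized(feature_value, threshold, left_id, right_id):
    # Left child when value <= threshold, right child otherwise.
    return left_id if feature_value <= threshold else right_id

# feature_0 = [67, 5] against threshold 15:
assert route_bucketized(67, 15, 1, 2) == 2  # leaf 8.79
assert route_bucketized(5, 15, 1, 2) == 1   # leaf 1.14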
Example #41
  def testCategoricalSplits(self):
    """Tests the training prediction work for categorical splits."""
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(
          """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
        }
      """, tree_ensemble_config)

      # Create existing ensemble with one root split
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      feature_0_values = [13, 1, 3]
      feature_1_values = [2, 2, 1]

      # No previous cached values.
      cached_tree_ids = [0, 0, 0]
      cached_node_ids = [0, 0, 0]

      # Grow tree ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      self.assertAllClose([0, 0, 0], new_tree_ids)
      self.assertAllClose([3, 4, 2], new_node_ids)
      self.assertAllClose([[5.], [6.], [7.]], logits_updates)
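categorical_split uses an equality test instead: an example goes left exactly when its feature value matches the split value. Again a sketch of the semantics only:

def route_categorical(feature_value, split_value, left_id, right_id):
    # Left child on an exact match, right child otherwise.
    return left_id if feature_value == split_value else right_id

# Example 0 above: feature_1 == 2 routes to node 1, then feature_0 == 13 to node 3.
assert route_categorical(2, 2, 1, 2) == 1
assert route_categorical(13, 13, 3, 4) == 3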
Example #42
    def testRestore(self):
        # Calling self.test_session() without a graph specified results in
        # TensorFlowTestCase caching the session and returning the same one
        # every time. In this test, we need to create two different sessions
        # which is why we also create a graph and pass it to self.test_session()
        # to ensure no caching occurs under the hood.
        save_path = os.path.join(self.get_temp_dir(), "restore-test")
        with ops.Graph().as_default() as graph:
            with self.test_session(graph) as sess:
                tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()

                tree = tree_ensemble_config.trees.add()
                tree_ensemble_config.tree_metadata.add().is_finalized = True
                tree_ensemble_config.tree_weights.append(1.0)
                _append_to_leaf(tree.nodes.add().leaf, 0, -0.1)

                tree_ensemble_config2 = tree_config_pb2.DecisionTreeEnsembleConfig()
                tree2 = tree_ensemble_config2.trees.add()
                tree_ensemble_config2.tree_weights.append(1.0)
                _append_to_leaf(tree2.nodes.add().leaf, 0, -1.0)

                tree_ensemble_config3 = tree_config_pb2.DecisionTreeEnsembleConfig()
                tree3 = tree_ensemble_config3.trees.add()
                tree_ensemble_config3.tree_weights.append(1.0)
                _append_to_leaf(tree3.nodes.add().leaf, 0, -10.0)

                # Prepare learner config.
                learner_config = learner_pb2.LearnerConfig()
                learner_config.num_classes = 2

                tree_ensemble_handle = model_ops.tree_ensemble_variable(
                    stamp_token=3,
                    tree_ensemble_config=tree_ensemble_config.
                    SerializeToString(),
                    name="restore_tree")
                feature_usage_counts = variables.Variable(
                    initial_value=array_ops.zeros([1], dtypes.int64),
                    name="feature_usage_counts",
                    trainable=False)
                feature_gains = variables.Variable(
                    initial_value=array_ops.zeros([1], dtypes.float32),
                    name="feature_gains",
                    trainable=False)

                resources.initialize_resources(
                    resources.shared_resources()).run()
                variables.initialize_all_variables().run()
                my_saver = saver.Saver()

                with ops.control_dependencies([
                        ensemble_optimizer_ops.add_trees_to_ensemble(
                            tree_ensemble_handle,
                            tree_ensemble_config2.SerializeToString(),
                            feature_usage_counts, [0],
                            feature_gains, [0], [[]],
                            learning_rate=1)
                ]):
                    result, _, _ = prediction_ops.gradient_trees_prediction(
                        tree_ensemble_handle,
                        self._seed, [self._dense_float_tensor], [
                            self._sparse_float_indices1,
                            self._sparse_float_indices2
                        ], [
                            self._sparse_float_values1,
                            self._sparse_float_values2
                        ],
                        [self._sparse_float_shape1, self._sparse_float_shape2],
                        [self._sparse_int_indices1],
                        [self._sparse_int_values1], [self._sparse_int_shape1],
                        learner_config=learner_config.SerializeToString(),
                        apply_dropout=False,
                        apply_averaging=False,
                        center_bias=False,
                        reduce_dim=True)
                self.assertAllClose([[-1.1], [-1.1]], result.eval())
                # Save before adding other trees.
                val = my_saver.save(sess, save_path)
                self.assertEqual(save_path, val)

                # Add more trees after saving.
                with ops.control_dependencies([
                        ensemble_optimizer_ops.add_trees_to_ensemble(
                            tree_ensemble_handle,
                            tree_ensemble_config3.SerializeToString(),
                            feature_usage_counts, [0],
                            feature_gains, [0], [[]],
                            learning_rate=1)
                ]):
                    result, _, _ = prediction_ops.gradient_trees_prediction(
                        tree_ensemble_handle,
                        self._seed, [self._dense_float_tensor], [
                            self._sparse_float_indices1,
                            self._sparse_float_indices2
                        ], [
                            self._sparse_float_values1,
                            self._sparse_float_values2
                        ],
                        [self._sparse_float_shape1, self._sparse_float_shape2],
                        [self._sparse_int_indices1],
                        [self._sparse_int_values1], [self._sparse_int_shape1],
                        learner_config=learner_config.SerializeToString(),
                        apply_dropout=False,
                        apply_averaging=False,
                        center_bias=False,
                        reduce_dim=True)
                self.assertAllClose(result.eval(), [[-11.1], [-11.1]])

        # Start a second session.  In that session the parameter nodes
        # have not been initialized either.
        with ops.Graph().as_default() as graph:
            with self.test_session(graph) as sess:
                tree_ensemble_handle = model_ops.tree_ensemble_variable(
                    stamp_token=0,
                    tree_ensemble_config="",
                    name="restore_tree")
                my_saver = saver.Saver()
                my_saver.restore(sess, save_path)
                result, _, _ = prediction_ops.gradient_trees_prediction(
                    tree_ensemble_handle,
                    self._seed, [self._dense_float_tensor],
                    [self._sparse_float_indices1, self._sparse_float_indices2],
                    [self._sparse_float_values1, self._sparse_float_values2],
                    [self._sparse_float_shape1, self._sparse_float_shape2],
                    [self._sparse_int_indices1], [self._sparse_int_values1],
                    [self._sparse_int_shape1],
                    learner_config=learner_config.SerializeToString(),
                    apply_dropout=False,
                    apply_averaging=False,
                    center_bias=False,
                    reduce_dim=True)
                # Make sure we only have the first and second tree.
                # The third tree was added after the save.
                self.assertAllClose(result.eval(), [[-1.1], [-1.1]])
Example #43

import numpy as np

hparams = tensor_forest.ForestHParams(num_classes=n_classes, num_features=n_features,
                                      num_trees=n_trees, max_nodes=max_nodes,
                                      split_after_samples=30).fill()

forest_graph = tensor_forest.RandomForestGraphs(hparams)

train_op = forest_graph.training_graph(x, y)
loss_op = forest_graph.training_loss(x, y)

infer_op, _, _ = forest_graph.inference_graph(x)

auc = tf.metrics.auc(tf.cast(y, tf.int64), infer_op[:, 1])[1]


init_vars = tf.group(tf.global_variables_initializer(),
                     tf.local_variables_initializer(),
                     resources.initialize_resources(resources.shared_resources()))

sess = tf.Session()

sess.run(init_vars)

batch_size = 1000

indices = list(range(n_train))

def gen_batch(indices):
    np.random.shuffle(indices)
    for batch_i in range(int(n_train / batch_size)):
        batch_index = indices[batch_i*batch_size: (batch_i+1)*batch_size]
        yield X_train_enc[batch_index], Y_train[batch_index]
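A hypothetical driver for gen_batch (sketch; x and y are the placeholders fed to training_graph above, and n_epochs is an assumed constant):

n_epochs = 10  # assumed
for epoch in range(n_epochs):
    for batch_x, batch_y in gen_batch(indices):
        _, l = sess.run([train_op, loss_op], feed_dict={x: batch_x, y: batch_y})
    # auc is a running metric; its local variables were initialized in init_vars.
    print('epoch %d: loss %f, auc %f'
          % (epoch, l, sess.run(auc, feed_dict={x: batch_x, y: batch_y})))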
Example #44
  def initialize_local_state(self, tf_config=None):
    """Called by the CombineFnWrapper's __init__ method.

    This can be used to set non-pickleable local state.  It is used in
    conjunction with overriding __reduce__ so this state is not pickled.  This
    method must be called prior to any other method.

    Args:
      tf_config: (optional) A tf.ConfigProto
    """
    # stamp_token is used to commit the state of the qaccumulator. In
    # this case, the qaccumulator state is completely returned and stored
    # as part of quantile_state/summary in the combiner fn (i.e the summary is
    # extracted and stored outside the qaccumulator). So we don't use
    # the timestamp mechanism to signify progress in the qaccumulator state.
    stamp_token = 0

    # Create a new session with a new graph for quantile ops.
    self._session = tf.Session(graph=tf.Graph(), config=tf_config)
    with self._session.graph.as_default():
      with self._session.as_default():
        self._qaccumulator = quantile_ops.QuantileAccumulator(
            init_stamp_token=stamp_token,
            num_quantiles=self._num_quantiles,
            epsilon=self._epsilon,
            name='qaccumulator')
        resources.initialize_resources(resources.shared_resources()).run()

        # Create a placeholder that will be used to provide input to the
        # QuantileAccumulator.  Has shape (1, None) as this is what the
        # QuantileAccumulator accepts.
        self._add_summary_input = tf.placeholder(
            dtype=self._bucket_numpy_dtype, shape=[1, None])

        # Create op to update the accumulator with new input fed from
        # self._qaccumulator_input.
        self._add_summary_op = self._qaccumulator.add_summary(
            stamp_token=stamp_token,
            column=self._add_summary_input,
            # All weights are equal, and the weight vector is the
            # same length as the input.
            example_weights=tf.ones_like(self._add_summary_input))

        # Create op to add a prebuilt summary to the accumulator, and a
        # placeholder tensor to provide the input for this op.
        self._prebuilt_summary_input = tf.placeholder(
            dtype=tf.string, shape=[])
        self._add_prebuilt_summary_op = self._qaccumulator.add_prebuilt_summary(
            stamp_token=stamp_token,
            summary=self._prebuilt_summary_input)

        # Create op to flush summaries and return a summary representing the
        # summaries that were added to the accumulator so far.
        self._flush_summary_op = self._qaccumulator.flush_summary(
            stamp_token=stamp_token,
            next_stamp_token=stamp_token)

        # Create ops to flush the accumulator and return approximate boundaries.
        self._flush_op = self._qaccumulator.flush(
            stamp_token=stamp_token,
            next_stamp_token=stamp_token)
        _, self._buckets_op = self._qaccumulator.get_buckets(
            stamp_token=stamp_token)

    # We generate an empty summary by calling self._flush_summary_op.
    # We cache this as some implementations may call create_accumulator for
    # every input, and it can be cached since it will always be the same and
    # immutable.
    self._empty_summary = self._session.run(self._flush_summary_op)
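Downstream combiner methods would then drive these cached ops, roughly as below (hypothetical usage, consistent with the ops defined above):

# Feed one batch of values, shape (1, None), through the cached op, then
# extract the merged summary for the combiner to carry forward.
batch = [[1.0, 2.0, 3.0]]
self._session.run(self._add_summary_op,
                  feed_dict={self._add_summary_input: batch})
summary = self._session.run(self._flush_summary_op)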
Example #45
 def testInferenceFromRestoredModel(self):
     input_data = [
         [-1., 0.],
         [-1., 2.],  # node 1
         [1., 0.],
         [1., -2.]
     ]  # node 2
     expected_prediction = [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]
     hparams = tensor_forest.ForestHParams(num_classes=2,
                                           num_features=2,
                                           num_trees=1,
                                           max_nodes=1000,
                                           split_after_samples=25).fill()
     tree_weight = {
         'decisionTree': {
             'nodes': [{
                 'binaryNode': {
                     'rightChildId': 2,
                     'leftChildId': 1,
                     'inequalityLeftChildTest': {
                         'featureId': {
                             'id': '0'
                         },
                         'threshold': {
                             'floatValue': 0
                         }
                     }
                 }
             }, {
                 'leaf': {
                     'vector': {
                         'value': [{
                             'floatValue': 0.0
                         }, {
                             'floatValue': 1.0
                         }]
                     }
                 },
                 'nodeId': 1
             }, {
                 'leaf': {
                     'vector': {
                         'value': [{
                             'floatValue': 0.0
                         }, {
                             'floatValue': 1.0
                         }]
                     }
                 },
                 'nodeId': 2
             }]
         }
     }
     restored_tree_param = ParseDict(
         tree_weight, _tree_proto.Model()).SerializeToString()
     graph_builder = tensor_forest.RandomForestGraphs(
         hparams, [restored_tree_param])
     probs, paths, var = graph_builder.inference_graph(input_data)
     self.assertTrue(isinstance(probs, ops.Tensor))
     self.assertTrue(isinstance(paths, ops.Tensor))
     self.assertTrue(isinstance(var, ops.Tensor))
     with self.cached_session():
         variables.global_variables_initializer().run()
         resources.initialize_resources(resources.shared_resources()).run()
          self.assertEqual(probs.eval().shape, (4, 2))
          self.assertEqual(probs.eval().tolist(), expected_prediction)
Example #46
  def testCategoricalSplits(self):
    """Tests the predictions work for categorical splits."""
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(
          """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
      """, tree_ensemble_config)

      # Create existing ensemble with one root split
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      feature_0_values = [13, 1, 3]
      feature_1_values = [2, 2, 1]

      expected_logits = [[5.], [6.], [7.]]

      # Prediction should work fine.
      predict_op = boosted_trees_ops.predict(
          tree_ensemble_handle,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits = session.run(predict_op)
      self.assertAllClose(expected_logits, logits)
Example #47
def run_LAmbDA2(gamma, delta, tau, prc_cut, bs_prc, num_trees, max_nodes):
    global X, Y, Gnp, Dnp, train, test, prt, cv
    D = tf.cast(Dnp, tf.float32)
    G = tf.cast(Gnp, tf.float32)
    #optunity_it = optunity_it+1;
    num_trees = int(num_trees)
    max_nodes = int(max_nodes)
    prc_cut = int(np.ceil(prc_cut))
    print(
        "gamma=%.4f, delta=%.4f, tau=%.4f, prc_cut=%i, bs_prc=%.4f, num_trees=%i, max_nodes=%i"
        % (gamma, delta, tau, prc_cut, bs_prc, num_trees, max_nodes))
    input_feats = X.shape[1]
    num_labls = G.shape.as_list()
    output_feats = num_labls[1]
    #print(output_feats)
    num_labls = num_labls[0]
    rowsums = np.sum(Gnp, axis=1)
    train2 = resample(prc_cut, Y, Gnp, train, gamma)
    # Bug??
    bs = int(np.ceil(bs_prc * train2.size))
    xs = tf.placeholder(tf.float32, [None, input_feats])
    #ys = tf.placeholder(tf.float32, [None,num_labls])
    yin = tf.placeholder(tf.int32, [None])
    print("Vars loaded xs and ys created")
    hparams = tensor_forest.ForestHParams(num_classes=output_feats,
                                          num_features=input_feats,
                                          num_trees=num_trees,
                                          max_nodes=max_nodes).fill()
    print("Tensor forest hparams created")
    forest_graph = tensor_forest.RandomForestGraphs(hparams)
    print("Tensor forest graph created")
    train_op = forest_graph.training_graph(xs, yin)
    loss_op = forest_graph.training_loss(xs, yin)
    print("Loss and train ops created")
    predict, _, _ = forest_graph.inference_graph(xs)
    print("Tensor forest variables created through predict")
    accuracy_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(tf.one_hot(yin, output_feats) - predict),
                      reduction_indices=[1]))
    print(
        tf.reduce_sum(tf.square(tf.one_hot(yin, output_feats) - predict),
                      reduction_indices=[1]))
    #predict = tf.one_hot(pred);
    print("Lambda specific variables created")
    # Creating training and testing steps
    G2 = np.copy(Gnp)
    G2[rowsums > 1, :] = 0
    YI = np.matmul(Y, G2)
    YIrs = np.sum(YI, axis=1)
    trainI = train2[np.in1d(train2, np.where(YIrs == 1))]
    print("data type trainI,", trainI.dtype)
    testI = test[np.in1d(test, np.where(YIrs == 1))]
    print("trainI testI created")
    #init_vars=tf.global_variables_initializer()
    init_vars = tf.group(
        tf.global_variables_initializer(),
        resources.initialize_resources(resources.shared_resources()))
    sess = tf.Session()
    sess.run(init_vars)
    print("Session started")
    #beep = sess.run(predict,feed_dict={xs:X[1:100,:]});
    #beep = sess.run(predict,feed_dict={xs:X[train2[0:bs],:]});
    tensor_trainI = {
        xs: X[trainI, :],
        yin: sess.run(tf.argmax(get_yi(rowsums, G2, Y[trainI, :]), axis=1))
    }
    print("tensor_trainI made")
    tensor_testI = {
        xs: X[testI, :],
        yin: sess.run(tf.argmax(get_yi(rowsums, G2, Y[testI, :]), axis=1))
    }
    print("tensor_testI made")
    tensor_train = {
        xs:
        X[train2[0:bs], :],
        yin:
        sess.run(
            tf.argmax(get_yn(
                sess.run(predict, feed_dict={xs: X[train2[0:bs], :]}),
                Y[train2[0:bs], :], delta, tau, output_feats),
                      axis=1))
    }
    print("tensor_train made")
    tensor_test = {
        xs:
        X[test, :],
        yin:
        sess.run(
            tf.argmax(get_yn(sess.run(predict, feed_dict={xs: X[test, :]}),
                             Y[test, :], delta, tau, output_feats),
                      axis=1))
    }
    print("tensor_test made")
    #**********************************
    #print("Loss and training steps created with sample tensors")
    # Setting params and initializing
    print("Beginning iterations")
    # Starting training iterations
    print(X.shape)
    for i in range(1, 101):
        if i < 50:
            sess.run(train_op, feed_dict=tensor_trainI)
            #print("ran train op")
            if i % 10 == 0:
                print(
                    str(sess.run(accuracy_op, feed_dict=tensor_trainI)) + ' ' +
                    str(sess.run(accuracy_op, feed_dict=tensor_testI)))
        else:
            sess.run(train_op, feed_dict=tensor_train)
            if i % 10 == 0:
                print(
                    str(sess.run(accuracy_op, feed_dict=tensor_train)) + ' ' +
                    str(sess.run(accuracy_op, feed_dict=tensor_test)))
            elif i % 10 == 0:  # NOTE: unreachable as written; same condition as the branch above.
                np.random.shuffle(train2)
                tensor_train = {
                    xs:
                    X[train2[0:bs], :],
                    yin:
                    sess.run(
                        get_yn(
                            sess.run(predict,
                                     feed_dict={xs: X[train2[0:bs], :]}),
                            Y[train2[0:bs], :], delta, tau, output_feats))
                }
    if prt:
        blah = sess.run(predict, feed_dict=tensor_test)
        sio.savemat('preds_cv' + str(cv) + '.mat', {'preds': blah})
        sio.savemat('truth_cv' + str(cv) + '.mat', {'labels': Y[test, :]})
    acc = sess.run(accuracy_op, feed_dict=tensor_test)
    print(
        "loss1=%.4f, gamma=%.4f, delta=%.4f, tau=%.4f, prc_cut=%i, bs_prc=%.4f, num_trees=%i, max_nodes=%i"
        % (acc, gamma, delta, tau, prc_cut, bs_prc, num_trees, max_nodes))
    tf.reset_default_graph()
    return acc
Example #48
    def testContribsMultipleTree(self):
        """Tests that the contribs work when we have multiple trees."""
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf: {scalar: 2.1}
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
        tree_metadata: {
          num_layers_grown: 1}
        tree_metadata: {
          num_layers_grown: 2}
        tree_metadata: {
          num_layers_grown: 1}
      """, tree_ensemble_config)

            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            feature_0_values = [36, 32]
            feature_1_values = [13, -29]  # Unused. Feature is not in above ensemble.
            feature_2_values = [11, 27]

            # Expected logits are computed by traversing the logit path and
            # subtracting child logits from parent logits.
            bias = 2.1 * 0.1  # Root node of tree_0.
            expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
            # example_0 :  (bias, 0.1 * 1.14, 0.2 * 5.5 + .114, 0.2 * 5. + .114,
            # 1.0 * 5.0 + 0.2 * 5. + .114)
            # example_1 :  (bias, 0.1 * 1.14, 0.2 * 7 + .114,
            # 1.0 * -7. + 0.2 * 7 + .114)
            expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114),
                                     (bias, 0.114, 1.514, -5.486))

            bucketized_features = [
                feature_0_values, feature_1_values, feature_2_values
            ]

            debug_op = boosted_trees_ops.example_debug_outputs(
                tree_ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            serialized_examples_debug_outputs = session.run(debug_op)
            feature_ids = []
            logits_paths = []
            for example in serialized_examples_debug_outputs:
                example_debug_outputs = boosted_trees_pb2.DebugOutput()
                example_debug_outputs.ParseFromString(example)
                feature_ids.append(example_debug_outputs.feature_ids)
                logits_paths.append(example_debug_outputs.logits_path)

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)
Example #49
 def default_init_op():
     return control_flow_ops.group(
         variables.global_variables_initializer(),
         resources.initialize_resources(
             resources.shared_resources()))
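In graph-mode TF 1.x this helper is run once, right after the graph is built. A minimal sketch under that assumption:

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import resources

tf.disable_eager_execution()  # assumes TF 1.x graph-mode semantics
init_op = tf.group(tf.global_variables_initializer(),
                   resources.initialize_resources(resources.shared_resources()))
with tf.Session() as sess:
    sess.run(init_op)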
Example #50
    def testCachedPredictionFromPreviousTree(self):
        """Tests the predictions work when we have cache from previous trees."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7
            }
          }
          nodes {
            leaf {
              scalar: 5
            }
          }
          nodes {
            leaf {
              scalar: 6
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: false
        }
        tree_weights: 0.1
        tree_weights: 0.1
        tree_weights: 0.1
      """, tree_ensemble_config)

            # Create existing ensemble with one root split
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples, one was cached in node 1 first, another in node 2.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # We have two features: 0 and 1.
            feature_0_values = [36, 32]
            feature_1_values = [11, 27]

            # Grow tree ensemble.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)
            # Example 1 will get to node 3 in tree 1 and node 2 of tree 2
            # Example 2 will get to node 2 in tree 1 and node 1 of tree 2

            # We are in the last tree.
            self.assertAllClose([2, 2], new_tree_ids)
            # When using the full ensemble, the first example ends up in node 2
            # and the second in node 1 of the last tree.
            self.assertAllClose([2, 1], new_node_ids)
            # Example 1: cached at node 1 of tree 0, so only trees 1 and 2
            #            contribute: change = 0.1*(5.0+5.0)
            # Example 2: cached at the root of tree 0, so tree 0 contributes too:
            #            change = 0.1*(1.14+7.0-7.0)
            self.assertAllClose([[1], [0.114]], logits_updates)
Example #51
    def testSerialization(self):
        with ops.Graph().as_default() as graph:
            with self.test_session(graph):
                tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
                # Bias tree only for second class.
                tree1 = tree_ensemble_config.trees.add()
                _append_to_leaf(tree1.nodes.add().leaf, 1, -0.2)

                tree_ensemble_config.tree_weights.append(1.0)

                # Depth 2 tree.
                tree2 = tree_ensemble_config.trees.add()
                tree_ensemble_config.tree_weights.append(1.0)
                _set_float_split(
                    tree2.nodes.add().sparse_float_binary_split_default_right.
                    split, 1, 4.0, 1, 2)
                _set_float_split(tree2.nodes.add().dense_float_binary_split, 0,
                                 9.0, 3, 4)
                _append_to_leaf(tree2.nodes.add().leaf, 0, 0.5)
                _append_to_leaf(tree2.nodes.add().leaf, 1, 1.2)
                _append_to_leaf(tree2.nodes.add().leaf, 0, -0.9)

                tree_ensemble_handle = model_ops.tree_ensemble_variable(
                    stamp_token=7,
                    tree_ensemble_config=tree_ensemble_config.
                    SerializeToString(),
                    name="saver_tree")
                stamp_token, serialized_config = model_ops.tree_ensemble_serialize(
                    tree_ensemble_handle)
                resources.initialize_resources(
                    resources.shared_resources()).run()
                self.assertEqual(stamp_token.eval(), 7)
                serialized_config = serialized_config.eval()

        with ops.Graph().as_default() as graph:
            with self.test_session(graph):
                tree_ensemble_handle2 = model_ops.tree_ensemble_variable(
                    stamp_token=9,
                    tree_ensemble_config=serialized_config,
                    name="saver_tree2")
                resources.initialize_resources(
                    resources.shared_resources()).run()

                # Prepare learner config.
                learner_config = learner_pb2.LearnerConfig()
                learner_config.num_classes = 3

                result, _, _ = prediction_ops.gradient_trees_prediction(
                    tree_ensemble_handle2,
                    self._seed, [self._dense_float_tensor],
                    [self._sparse_float_indices1, self._sparse_float_indices2],
                    [self._sparse_float_values1, self._sparse_float_values2],
                    [self._sparse_float_shape1, self._sparse_float_shape2],
                    [self._sparse_int_indices1], [self._sparse_int_values1],
                    [self._sparse_int_shape1],
                    learner_config=learner_config.SerializeToString(),
                    apply_dropout=False,
                    apply_averaging=False,
                    center_bias=False,
                    reduce_dim=True)

                # Re-serialize tree.
                stamp_token2, serialized_config2 = model_ops.tree_ensemble_serialize(
                    tree_ensemble_handle2)

                # The first example will get bias class 1 -0.2 from first tree and
                # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2],
                # the second example will get the same bias class 1 -0.2 and leaf 3
                # payload of class 1 1.2 hence [0.0, 1.0].
                self.assertEqual(stamp_token2.eval(), 9)

                # Class 2 does not have scores in the leaf => it gets score 0.
                self.assertEqual(serialized_config2.eval(), serialized_config)
                self.assertAllClose(result.eval(), [[0.5, -0.2], [0, 1.0]])
Example #52
    def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
        """Tests that prediction based on previous node in the tree works."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id:0
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 5
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
              scalar: 0.0783
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.55
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 4
        }
      """, tree_ensemble_config)

            # Create existing ensemble.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            cached_tree_ids = [0, 0, 0, 0, 0, 0]
            # Leaves 3, 4, 7 and 8 were deleted during post-pruning; leaves 5
            # and 6 had their ids changed to 3 and 4 respectively.
            cached_node_ids = [3, 4, 5, 6, 7, 8]

            # We have two features: 0 and 1.
            feature_0_values = [12, 17, 35, 36, 23, 11]
            feature_1_values = [12, 12, 17, 18, 123, 24]

            # Grow tree ensemble.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # We are in the last tree.
            self.assertAllClose([1, 1, 1, 1, 1, 1], new_tree_ids)
            # In tree 0, examples from leaves 3, 4, 7 and 8 map to leaf 1, and
            # examples from leaves 5 and 6 map to leaves 3 and 4. In tree 1,
            # all of the examples are in the root node.
            self.assertAllClose([0, 0, 0, 0, 0, 0], new_node_ids)

            cached_values = [[0.08], [0.093], [0.0553], [0.0783],
                             [0.15 + 0.08], [0.5 + 0.08]]
            root = 0.55
            self.assertAllClose(
                [[root + 0.01], [root + 0.01], [root + 0.0553],
                 [root + 0.0783], [root + 0.01], [root + 0.01]],
                logits_updates + cached_values)
Example #53
    def testGenerateFeatureSplitCandidatesMulticlass(self):
        with self.test_session() as sess:
            # Batch size is 4; 2 gradients per instance.
            gradients = array_ops.constant(
                [[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]],
                shape=[4, 2])
            # 2x2 matrix for each instance
            hessian_0 = [[0.12, 0.02], [0.3, 0.11]]
            hessian_1 = [[0.07, -0.2], [-0.5, 0.2]]
            hessian_2 = [[0.2, -0.23], [-0.8, 0.9]]
            hessian_3 = [[0.13, -0.3], [-1.5, 2.2]]
            hessians = array_ops.constant(
                [hessian_0, hessian_1, hessian_2, hessian_3])

            partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
            indices = [[0, 0], [0, 1], [2, 0], [3, 0]]
            values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64)

            gradient_shape = tensor_shape.TensorShape([2])
            hessian_shape = tensor_shape.TensorShape([2, 2])
            class_id = -1

            split_handler = categorical_split_handler.EqualitySplitHandler(
                l1_regularization=0.1,
                l2_regularization=1,
                tree_complexity_regularization=0,
                min_node_weight=0,
                sparse_int_column=sparse_tensor.SparseTensor(
                    indices, values, [4, 1]),
                feature_column_group_id=0,
                gradient_shape=gradient_shape,
                hessian_shape=hessian_shape,
                multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
                init_stamp_token=0)
            resources.initialize_resources(resources.shared_resources()).run()

            empty_gradients, empty_hessians = get_empty_tensors(
                gradient_shape, hessian_shape)
            example_weights = array_ops.ones([4, 1], dtypes.float32)

            update_1 = split_handler.update_stats_sync(
                0,
                partition_ids,
                gradients,
                hessians,
                empty_gradients,
                empty_hessians,
                example_weights,
                is_active=array_ops.constant([True, True]))
            with ops.control_dependencies([update_1]):
                are_splits_ready, partitions, gains, splits = (
                    split_handler.make_splits(0, 1, class_id))
                are_splits_ready, partitions, gains, splits = (sess.run(
                    [are_splits_ready, partitions, gains, splits]))
        self.assertTrue(are_splits_ready)
        self.assertAllEqual([0, 1], partitions)

        split_info = split_info_pb2.SplitInfo()
        split_info.ParseFromString(splits[0])

        left_child = split_info.left_child.vector
        right_child = split_info.right_child.vector
        split_node = split_info.split_node.categorical_id_binary_split
        # Each leaf has 2 element vector.
        self.assertEqual(2, len(left_child.value))
        self.assertEqual(2, len(right_child.value))
        self.assertEqual(1, split_node.feature_id)

        split_info.ParseFromString(splits[1])
        left_child = split_info.left_child.vector
        right_child = split_info.right_child.vector
        split_node = split_info.split_node.categorical_id_binary_split
        self.assertEqual(2, len(left_child.value))
        self.assertEqual(0, len(right_child.value))
        self.assertEqual(1, split_node.feature_id)
Example #54
    def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
        """Tests case when, after training, first tree contains only a bias node."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
        tree_metadata: {
          num_layers_grown: 1
        }
      """, tree_ensemble_config)

            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            feature_0_values = [36, 32]
            feature_1_values = [13, -29]  # Unused feature.
            feature_2_values = [11, 27]

            # Expected logits are computed by traversing the logit path and
            # subtracting child logits from parent logits.
            expected_feature_ids = ((2, 0), (2, ))
            # bias = 1.72 * 1.  # Root node of tree_0.
            # example_0 :  (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
            # example_1 :  (bias, 0.1 * 7. + bias )
            expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))

            bucketized_features = [
                feature_0_values, feature_1_values, feature_2_values
            ]

            debug_op = boosted_trees_ops.example_debug_outputs(
                tree_ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            serialized_examples_debug_outputs = session.run(debug_op)
            feature_ids = []
            logits_paths = []
            for example in serialized_examples_debug_outputs:
                example_debug_outputs = boosted_trees_pb2.DebugOutput()
                example_debug_outputs.ParseFromString(example)
                feature_ids.append(example_debug_outputs.feature_ids)
                logits_paths.append(example_debug_outputs.logits_path)

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)
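
The expected paths above can be checked by hand: each entry appends tree_weight * leaf_scalar to the running bias. A standalone sketch in plain Python (no TensorFlow needed):

import math

bias = 1.72 * 1.0   # Root leaf of tree_0, tree weight 1.
lr = 0.1            # tree_weights[1].

# example_0: feature_2 = 11 <= 26 goes left (original leaf 5.5),
# then feature_0 = 36 <= 50 goes left again to leaf 5.0.
example_0 = (bias, lr * 5.5 + bias, lr * 5.0 + bias)
# example_1: feature_2 = 27 > 26 goes right, straight to leaf 7.0.
example_1 = (bias, lr * 7.0 + bias)

assert all(math.isclose(a, b) for a, b in zip(example_0, (1.72, 2.27, 2.22)))
assert all(math.isclose(a, b) for a, b in zip(example_1, (1.72, 2.42)))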
Example #55
    def testCachedPredictionIsCurrent(self):
        """Tests that prediction based on previous node in the tree works."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

            # Create existing ensemble with one root split
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples: one was cached in node 1, the other in node 2.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 2]

            # We have two features: 0 and 1. Values don't matter because trees didn't
            # change.
            feature_0_values = [67, 5]
            feature_1_values = [9, 17]

            # Grow tree ensemble.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # Nothing changed.
            self.assertAllClose(cached_tree_ids, new_tree_ids)
            self.assertAllClose(cached_node_ids, new_node_ids)
            self.assertAllClose([[0], [0]], logits_updates)
Example #56
    def testCachedPredictionFromTheSameTree(self):
        """Tests that prediction based on previous node in the tree works."""
        with self.test_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 7.14
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
          nodes {
            leaf {
              scalar: -5.875
            }
          }
          nodes {
            leaf {
              scalar: -2.075
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

            # Create an existing ensemble with one finalized depth-2 tree.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples: one was cached in node 1, the other in node 0.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # We have two features: 0 and 1.
            feature_0_values = [67, 5]
            feature_1_values = [9, 17]

            # Grow tree ensemble.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # We are still in the same tree.
            self.assertAllClose([0, 0], new_tree_ids)
            # When using the full tree, the first example will end up in node 4,
            # the second in node 5.
            self.assertAllClose([4, 5], new_node_ids)
            # Full predictions for each instance would be 8.79 and -5.875,
            # while the cached (original leaf) values were 7.14 and -2, so the
            # updates are lr * (8.79 - 7.14) and lr * (-5.875 - (-2)), i.e.
            # 0.1 * 1.65 and 0.1 * -3.875.
            self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logits_updates)
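
In other words, for a stale cache the op emits tree_weight * (value of the leaf the full tree reaches - cached value); when the cache is already current, as in the previous example, that difference is zero. A hedged plain-Python restatement of the numbers above (illustrative names, not the op's internals):

import math

lr = 0.1  # tree_weights[0].

# (cached original-leaf value, leaf value the full tree reaches)
cases = [(7.14, 8.79),    # example_0: cached at node 1, lands in node 4.
         (-2.0, -5.875)]  # example_1: cached at node 0, lands in node 5.

updates = [lr * (new - cached) for cached, new in cases]
expected = [0.1 * 1.65, 0.1 * -3.875]
assert all(math.isclose(u, e) for u, e in zip(updates, expected))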
Example #57
    def testSerializeDeserialize(self):
        with self.cached_session():
            # Initialize.
            ensemble = boosted_trees_ops.TreeEnsemble('ensemble',
                                                      stamp_token=5)
            resources.initialize_resources(resources.shared_resources()).run()
            (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
             nodes_range) = ensemble.get_states()
            self.assertEqual(5, self.evaluate(stamp_token))
            self.assertEqual(0, self.evaluate(num_trees))
            self.assertEqual(0, self.evaluate(num_finalized_trees))
            self.assertEqual(0, self.evaluate(num_attempted_layers))
            self.assertAllEqual([0, 1], self.evaluate(nodes_range))

            # Deserialize.
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 75
              threshold: 21
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -1.4
            }
          }
          nodes {
            leaf {
              scalar: -0.6
            }
          }
          nodes {
            leaf {
              scalar: 0.165
            }
          }
        }
        tree_weights: 0.5
        tree_metadata {
          num_layers_grown: 4  # it's fake intentionally.
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 5
          last_layer_node_start: 3
          last_layer_node_end: 7
        }
      """, ensemble_proto)
            with ops.control_dependencies([
                    ensemble.deserialize(
                        stamp_token=3,
                        serialized_proto=ensemble_proto.SerializeToString())
            ]):
                (stamp_token, num_trees, num_finalized_trees,
                 num_attempted_layers, nodes_range) = ensemble.get_states()
            self.assertEqual(3, self.evaluate(stamp_token))
            self.assertEqual(1, self.evaluate(num_trees))
            # This reads from metadata, not really counting the layers.
            self.assertEqual(5, self.evaluate(num_attempted_layers))
            self.assertEqual(0, self.evaluate(num_finalized_trees))
            self.assertAllEqual([3, 7], self.evaluate(nodes_range))

            # Serialize.
            new_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            new_stamp_token, new_serialized = ensemble.serialize()
            self.assertEqual(3, self.evaluate(new_stamp_token))
            new_ensemble_proto.ParseFromString(new_serialized.eval())
            self.assertProtoEquals(ensemble_proto, new_ensemble_proto)
Example #58
# Imports needed to run this snippet standalone; build_vocab comes from
# elsewhere in the same source file.
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources


def main(args):
    vocab = build_vocab(args.data_path)
    data = pd.DataFrame({
        'label': vocab.labels,
        'lprox': vocab.lprox,
        'rprox': vocab.rprox,
        'x': vocab.x,
        'y': vocab.y,
        'z': vocab.z,
    })
    y = data['label']
    lprox = pd.DataFrame(data['lprox'].values.tolist())
    rprox = pd.DataFrame(data['rprox'].values.tolist())
    xax = pd.DataFrame(data['x'].values.tolist())
    yax = pd.DataFrame(data['y'].values.tolist())
    zax = pd.DataFrame(data['z'].values.tolist())

    X = pd.concat([lprox, rprox, xax, yax, zax], axis=1)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

    num_steps = 100  # Total steps to train
    num_classes = 2
    num_features = 585
    num_trees = 10
    max_nodes = 1000

    # Note: X and Y are rebound from the pandas objects above to placeholders;
    # the feed_dicts below rely on these names.
    X = tf.placeholder(tf.float32, shape=[None, num_features])
    Y = tf.placeholder(tf.int64, shape=[None])

    hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                          num_features=num_features,
                                          num_trees=num_trees,
                                          max_nodes=max_nodes).fill()
    forest_graph = tensor_forest.RandomForestGraphs(hparams)
    train_op = forest_graph.training_graph(X, Y)

    loss_op = forest_graph.training_loss(X, Y)
    infer_op, _, _ = forest_graph.inference_graph(X)
    correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
    accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    init_vars = tf.group(
        tf.global_variables_initializer(),
        resources.initialize_resources(resources.shared_resources()))

    sess = tf.Session()
    sess.run(init_vars)

    # Build the Saver once, outside the training loop, so each iteration does
    # not keep adding save ops to the graph.
    saver = tf.train.Saver()
    for i in range(1, num_steps + 1):
        _, l = sess.run([train_op, loss_op],
                        feed_dict={
                            X: X_train,
                            Y: y_train
                        })
        if i % 50 == 0 or i == 1:
            acc = sess.run(accuracy_op, feed_dict={X: X_train, Y: y_train})
            saver.save(sess, 'models/model%i.ckpt' % i)
            print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

    print("Test Accuracy:",
          sess.run(accuracy_op, feed_dict={
              X: X_test,
              Y: y_test
          }))
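
build_vocab and args come from the surrounding script, which is not shown here; a hypothetical argparse entry point like the following is one way to drive main:

if __name__ == '__main__':
    import argparse

    # Hypothetical harness; --data_path mirrors the args.data_path used above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', required=True)
    main(parser.parse_args())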
Example #59
 def testCreateWithProto(self):
     with self.cached_session():
         ensemble_proto = boosted_trees_pb2.TreeEnsemble()
         text_format.Merge(
             """
     trees {
       nodes {
         bucketized_split {
           feature_id: 4
           left_id: 1
           right_id: 2
         }
         metadata {
           gain: 7.62
         }
       }
       nodes {
         bucketized_split {
           threshold: 21
           left_id: 3
           right_id: 4
         }
         metadata {
           gain: 1.4
           original_leaf {
             scalar: 7.14
           }
         }
       }
       nodes {
         bucketized_split {
           feature_id: 1
           threshold: 7
           left_id: 5
           right_id: 6
         }
         metadata {
           gain: 2.7
           original_leaf {
             scalar: -4.375
           }
         }
       }
       nodes {
         leaf {
           scalar: 6.54
         }
       }
       nodes {
         leaf {
           scalar: 7.305
         }
       }
       nodes {
         leaf {
           scalar: -4.525
         }
       }
       nodes {
         leaf {
           scalar: -4.145
         }
       }
     }
     trees {
       nodes {
         bucketized_split {
           feature_id: 75
           threshold: 21
           left_id: 1
           right_id: 2
         }
         metadata {
           gain: -1.4
         }
       }
       nodes {
         leaf {
           scalar: -0.6
         }
       }
       nodes {
         leaf {
           scalar: 0.165
         }
       }
     }
     tree_weights: 0.15
     tree_weights: 1.0
     tree_metadata {
       num_layers_grown: 2
       is_finalized: true
     }
     tree_metadata {
       num_layers_grown: 1
       is_finalized: false
     }
     growing_metadata {
       num_trees_attempted: 2
       num_layers_attempted: 6
       last_layer_node_start: 16
       last_layer_node_end: 19
     }
   """, ensemble_proto)
         ensemble = boosted_trees_ops.TreeEnsemble(
             'ensemble',
             stamp_token=7,
             serialized_proto=ensemble_proto.SerializeToString())
         resources.initialize_resources(resources.shared_resources()).run()
         (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
          nodes_range) = ensemble.get_states()
         self.assertEqual(7, self.evaluate(stamp_token))
         self.assertEqual(2, self.evaluate(num_trees))
         self.assertEqual(1, self.evaluate(num_finalized_trees))
         self.assertEqual(6, self.evaluate(num_attempted_layers))
         self.assertAllEqual([16, 19], self.evaluate(nodes_range))
Example #60
    def testGenerateFeatureSplitCandidates(self):
        with self.test_session() as sess:
            # The data looks like the following:
            # Example |  Gradients    | Partition | Feature ID     |
            # i0      |  (0.2, 0.12)  | 0         | 1,2            |
            # i1      |  (-0.5, 0.07) | 0         |                |
            # i2      |  (1.2, 0.2)   | 0         | 2              |
            # i3      |  (4.0, 0.13)  | 1         | 1              |
            gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
            hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
            partition_ids = [0, 0, 0, 1]
            indices = [[0, 0], [0, 1], [2, 0], [3, 0]]
            values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64)

            gradient_shape = tensor_shape.scalar()
            hessian_shape = tensor_shape.scalar()
            class_id = -1

            split_handler = categorical_split_handler.EqualitySplitHandler(
                l1_regularization=0.1,
                l2_regularization=1,
                tree_complexity_regularization=0,
                min_node_weight=0,
                sparse_int_column=sparse_tensor.SparseTensor(
                    indices, values, [4, 1]),
                feature_column_group_id=0,
                gradient_shape=gradient_shape,
                hessian_shape=hessian_shape,
                multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
                init_stamp_token=0)
            resources.initialize_resources(resources.shared_resources()).run()

            empty_gradients, empty_hessians = get_empty_tensors(
                gradient_shape, hessian_shape)
            example_weights = array_ops.ones([4, 1], dtypes.float32)

            update_1 = split_handler.update_stats_sync(
                0,
                partition_ids,
                gradients,
                hessians,
                empty_gradients,
                empty_hessians,
                example_weights,
                is_active=array_ops.constant([True, True]))
            update_2 = split_handler.update_stats_sync(
                0,
                partition_ids,
                gradients,
                hessians,
                empty_gradients,
                empty_hessians,
                example_weights,
                is_active=array_ops.constant([True, True]))

            with ops.control_dependencies([update_1, update_2]):
                are_splits_ready, partitions, gains, splits = (
                    split_handler.make_splits(0, 1, class_id))
                are_splits_ready, partitions, gains, splits = (sess.run(
                    [are_splits_ready, partitions, gains, splits]))
        self.assertTrue(are_splits_ready)
        self.assertAllEqual([0, 1], partitions)

        # Check the split on partition 0.
        # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
        expected_left_weight = -0.9848484848484846

        # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
        expected_left_gain = 1.2803030303030298

        # -(-0.5 + 0.1) / (0.07 + 1)
        expected_right_weight = 0.37383177570093457

        # (-0.5 + 0.1) ** 2 / (0.07 + 1)
        expected_right_gain = 0.14953271028037385

        # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
        expected_bias_gain = 0.46043165467625885

        split_info = split_info_pb2.SplitInfo()
        split_info.ParseFromString(splits[0])
        left_child = split_info.left_child.vector
        right_child = split_info.right_child.vector
        split_node = split_info.split_node.categorical_id_binary_split

        self.assertEqual(0, split_node.feature_column)

        self.assertEqual(2, split_node.feature_id)

        self.assertAllClose(
            expected_left_gain + expected_right_gain - expected_bias_gain,
            gains[0], 0.00001)

        self.assertAllClose([expected_left_weight], left_child.value, 0.00001)

        self.assertAllClose([expected_right_weight], right_child.value,
                            0.00001)

        # Check the split on partition 1.
        # (-4 + 0.1) / (0.13 + 1)
        expected_left_weight = -3.4513274336283186
        # (-4 + 0.1) ** 2 / (0.13 + 1)
        expected_left_gain = 13.460176991150442
        expected_right_weight = 0
        expected_right_gain = 0
        # (-4 + 0.1) ** 2 / (0.13 + 1)
        expected_bias_gain = 13.460176991150442

        # Verify the candidate for partition 1; only one feature is active
        # here, so zero gain is expected.
        split_info = split_info_pb2.SplitInfo()
        split_info.ParseFromString(splits[1])
        left_child = split_info.left_child.vector
        right_child = split_info.right_child.vector
        split_node = split_info.split_node.categorical_id_binary_split
        self.assertAllClose(0.0, gains[1], 0.00001)

        self.assertAllClose([expected_left_weight], left_child.value, 0.00001)

        self.assertAllClose([expected_right_weight], right_child.value,
                            0.00001)

        self.assertEqual(0, split_node.feature_column)

        self.assertEqual(1, split_node.feature_id)
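
The expected constants in both partitions follow one formula, matching the arithmetic spelled out in the comments: soft-threshold the accumulated gradient by l1, then weight = -g / (h + l2) and gain = g^2 / (h + l2). A standalone sketch reproducing the partition-0 numbers (plain Python; assumes this shrinkage form, as the comments above do):

import math

l1, l2 = 0.1, 1.0

def weight_and_gain(g, h):
    # Soft-threshold g by l1, then apply the usual regularized leaf formulas.
    g = g - l1 if g > l1 else (g + l1 if g < -l1 else 0.0)
    return -g / (h + l2), g * g / (h + l2)

left_w, left_gain = weight_and_gain(0.2 + 1.2, 0.12 + 0.2)    # Feature id 2.
right_w, right_gain = weight_and_gain(-0.5, 0.07)             # The rest.
_, bias_gain = weight_and_gain(0.2 - 0.5 + 1.2, 0.12 + 0.07 + 0.2)

assert math.isclose(left_w, -0.9848484848484846)
assert math.isclose(right_w, 0.37383177570093457)
assert math.isclose(left_gain + right_gain - bias_gain,
                    1.2803030303030298 + 0.14953271028037385
                    - 0.46043165467625885)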