Ejemplo n.º 1
0
  def testContribsForOnlyABiasNode(self):
    """Debug outputs for an ensemble that consists of a single bias leaf.

    Training can leave such an ensemble behind, e.g. when the only tree in the
    final ensemble was pruned all the way back to its root.
    """
    with self.cached_session() as session:
      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(
          """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
      """, ensemble_proto)

      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_proto.SerializeToString())
      ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # The ensemble has no splits, so none of these columns is ever read.
      bucketized_features = [
          [36, 32],  # feature 0
          [13, -29],  # feature 1
          [11, 27],  # feature 2
      ]

      # Without splits, each example visits no features and its logit path
      # holds only tree_0's weighted root leaf.
      bias = 1.72 * 0.1
      expected_feature_ids = ((), ())
      expected_logits_paths = ((bias,), (bias,))

      debug_outputs_op = boosted_trees_ops.example_debug_outputs(
          ensemble_handle,
          bucketized_features=bucketized_features,
          logits_dimension=1)

      parsed_outputs = []
      for serialized in session.run(debug_outputs_op):
        debug_output = boosted_trees_pb2.DebugOutput()
        debug_output.ParseFromString(serialized)
        parsed_outputs.append(debug_output)

      self.assertAllClose([out.feature_ids for out in parsed_outputs],
                          expected_feature_ids)
      self.assertAllClose([out.logits_path for out in parsed_outputs],
                          expected_logits_paths)
Ejemplo n.º 2
0
 def _assert_checkpoint(self, model_dir, global_step, finalized_trees,
                        attempted_layers):
   """Asserts the boosted-trees state stored in the checkpoint in `model_dir`.

   Args:
     model_dir: Directory handed to `checkpoint_utils.load_checkpoint`.
     global_step: Expected value of the saved global step tensor.
     finalized_trees: Expected number of `tree_metadata` entries with
       `is_finalized` set.
     attempted_layers: Expected `growing_metadata.num_layers_attempted`.
   """
   reader = checkpoint_utils.load_checkpoint(model_dir)
   self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
   # The ensemble is saved as a serialized TreeEnsemble proto under this key.
   serialized = reader.get_tensor('boosted_trees:0_serialized')
   ensemble_proto = boosted_trees_pb2.TreeEnsemble()
   ensemble_proto.ParseFromString(serialized)
   # Generator expression: no need to materialize a list just to count.
   self.assertEqual(
       finalized_trees,
       sum(1 for t in ensemble_proto.tree_metadata if t.is_finalized))
   self.assertEqual(attempted_layers,
                    ensemble_proto.growing_metadata.num_layers_attempted)
Ejemplo n.º 3
0
    def testContribsMultipleTree(self):
        """Checks feature ids and logit paths across a three-tree ensemble.

        Each expected logit path starts at the weighted original leaf of
        tree_0's root and then lists the running ensemble logit at every node
        the example reaches.
        """
        with self.cached_session() as session:
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf: {scalar: 2.1}
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
        tree_metadata: {
          num_layers_grown: 1}
        tree_metadata: {
          num_layers_grown: 2}
        tree_metadata: {
          num_layers_grown: 1}
      """, ensemble_proto)

            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_proto.SerializeToString())
            ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            bucketized_features = [
                [36, 32],  # feature 0
                [13, -29],  # feature 1: no split in the ensemble reads it.
                [11, 27],  # feature 2
            ]

            # tree_0's root contributes its weighted original leaf as the bias.
            bias = 2.1 * 0.1
            expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
            # example_0: (bias, 0.1 * 1.14, 0.2 * 5.5 + .114, 0.2 * 5. + .114,
            #             1.0 * 5.0 + 0.2 * 5. + .114)
            # example_1: (bias, 0.1 * 1.14, 0.2 * 7 + .114,
            #             1.0 * -7. + 0.2 * 7 + .114)
            expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114),
                                     (bias, 0.114, 1.514, -5.486))

            debug_outputs_op = boosted_trees_ops.example_debug_outputs(
                ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            parsed_outputs = []
            for serialized in session.run(debug_outputs_op):
                debug_output = boosted_trees_pb2.DebugOutput()
                debug_output.ParseFromString(serialized)
                parsed_outputs.append(debug_output)

            self.assertAllClose([out.feature_ids for out in parsed_outputs],
                                expected_feature_ids)
            self.assertAllClose([out.logits_path for out in parsed_outputs],
                                expected_logits_paths)
Ejemplo n.º 4
0
    def testSerializeDeserialize(self):
        """Round-trips a tree ensemble through deserialize() and serialize().

        Checks get_states() on a freshly created ensemble, then after loading
        a hand-written proto, and finally that serializing again returns the
        same proto together with the new stamp token.
        """
        with self.cached_session():
            # Initialize.
            ensemble = boosted_trees_ops.TreeEnsemble('ensemble',
                                                      stamp_token=5)
            resources.initialize_resources(resources.shared_resources()).run()
            (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
             nodes_range) = ensemble.get_states()
            self.assertEqual(5, self.evaluate(stamp_token))
            self.assertEqual(0, self.evaluate(num_trees))
            self.assertEqual(0, self.evaluate(num_finalized_trees))
            self.assertEqual(0, self.evaluate(num_attempted_layers))
            self.assertAllEqual([0, 1], self.evaluate(nodes_range))

            # Deserialize.
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 75
              threshold: 21
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -1.4
            }
          }
          nodes {
            leaf {
              scalar: -0.6
            }
          }
          nodes {
            leaf {
              scalar: 0.165
            }
          }
        }
        tree_weights: 0.5
        tree_metadata {
          num_layers_grown: 4  # it's fake intentionally.
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 5
          last_layer_node_start: 3
          last_layer_node_end: 7
        }
      """, ensemble_proto)
            # get_states() must observe the ensemble after deserialization.
            with ops.control_dependencies([
                    ensemble.deserialize(
                        stamp_token=3,
                        serialized_proto=ensemble_proto.SerializeToString())
            ]):
                (stamp_token, num_trees, num_finalized_trees,
                 num_attempted_layers, nodes_range) = ensemble.get_states()
            self.assertEqual(3, self.evaluate(stamp_token))
            self.assertEqual(1, self.evaluate(num_trees))
            # This reads growing_metadata.num_layers_attempted (5), not the
            # tree's num_layers_grown (4), and does not count layers itself.
            self.assertEqual(5, self.evaluate(num_attempted_layers))
            self.assertEqual(0, self.evaluate(num_finalized_trees))
            # Matches last_layer_node_start / last_layer_node_end above.
            self.assertAllEqual([3, 7], self.evaluate(nodes_range))

            # Serialize.
            new_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            new_stamp_token, new_serialized = ensemble.serialize()
            self.assertEqual(3, self.evaluate(new_stamp_token))
            # Use self.evaluate like every other tensor read in this test.
            new_ensemble_proto.ParseFromString(self.evaluate(new_serialized))
            self.assertProtoEquals(ensemble_proto, new_ensemble_proto)
Ejemplo n.º 5
0
  def testCategoricalSplits(self):
    """Verifies predict() on a single tree built from categorical splits.

    Node 0 tests feature 1 against value 2 and node 1 tests feature 0 against
    value 13; nodes 2, 3 and 4 are leaves holding 7.0, 5.0 and 6.0.
    """
    with self.cached_session() as session:
      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(
          """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
      """, ensemble_proto)

      # Load the ensemble: one tree with two categorical splits.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_proto.SerializeToString())
      ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      bucketized_features = [
          [13, 1, 3],  # feature 0
          [2, 2, 1],  # feature 1
      ]

      # Per the expected values: matching both categories reaches leaf 5.0,
      # matching only feature 1 reaches 6.0, and missing the root category
      # reaches 7.0.
      expected_logits = [[5.], [6.], [7.]]

      logits = session.run(
          boosted_trees_ops.predict(
              ensemble_handle,
              bucketized_features=bucketized_features,
              logits_dimension=1))
      self.assertAllClose(expected_logits, logits)
Ejemplo n.º 6
0
    def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
        """Tests case when, after training, first tree contains only a bias node.

        tree_0 is a single leaf (1.72, weight 1.0), so it contributes the bias
        and no feature ids; only tree_1's splits appear in the debug outputs.
        """
        # cached_session() replaces the deprecated test_session() and matches
        # the other contribs tests.
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
        tree_metadata: {
          num_layers_grown: 1
        }
      """, tree_ensemble_config)

            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            feature_0_values = [36, 32]
            feature_1_values = [13, -29]  # Unused feature.
            feature_2_values = [11, 27]

            # Expected logits are computed by traversing the logit path and
            # subtracting child logits from parent logits.
            expected_feature_ids = ((2, 0), (2, ))
            # bias = 1.72 * 1.  # Root node of tree_0.
            # example_0 :  (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
            # example_1 :  (bias, 0.1 * 7. + bias )
            expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))

            bucketized_features = [
                feature_0_values, feature_1_values, feature_2_values
            ]

            debug_op = boosted_trees_ops.example_debug_outputs(
                tree_ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            serialized_examples_debug_outputs = session.run(debug_op)
            feature_ids = []
            logits_paths = []
            for example in serialized_examples_debug_outputs:
                example_debug_outputs = boosted_trees_pb2.DebugOutput()
                example_debug_outputs.ParseFromString(example)
                feature_ids.append(example_debug_outputs.feature_ids)
                logits_paths.append(example_debug_outputs.logits_path)

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)
Ejemplo n.º 7
0
    def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
        """Tests training prediction from cache entries made before post-pruning.

        Examples were cached at node ids of tree 0 as it was before pruning.
        The post_pruned_nodes_meta entries remap those ids and supply logit
        corrections, after which every example continues into tree 1.
        """
        # cached_session() replaces the deprecated test_session().
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id:0
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 5
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
              scalar: 0.0783
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.55
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 4
        }
      """, tree_ensemble_config)

            # Create existing ensemble.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            cached_tree_ids = [0, 0, 0, 0, 0, 0]
            # Leaves 3,4, 7 and 8 got deleted during post-pruning, leaves 5 and 6
            # changed the ids to 3 and 4 respectively.
            cached_node_ids = [3, 4, 5, 6, 7, 8]

            # We have two features: 0 and 1.
            feature_0_values = [12, 17, 35, 36, 23, 11]
            feature_1_values = [12, 12, 17, 18, 123, 24]

            # Run training-time prediction starting from the cached positions.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # We are in the last tree.
            self.assertAllClose([1, 1, 1, 1, 1, 1], new_tree_ids)
            # Examples from leaves 3,4,7,8 should be in leaf 1, examples from leaf 5
            # and 6 in leaf 3 and 4 in tree 0. For tree 1, all of the examples are in
            # the root node.
            self.assertAllClose([0, 0, 0, 0, 0, 0], new_node_ids)

            # Cached logits from before pruning. Adding the matching
            # logit_change (-0.07, -0.083, -0.22, -0.57) to the 1st, 2nd, 5th
            # and 6th values gives 0.01, the value of leaf 1.
            cached_values = [[0.08], [0.093], [0.0553], [0.0783],
                             [0.15 + 0.08], [0.5 + 0.08]]
            root = 0.55
            self.assertAllClose(
                [[root + 0.01], [root + 0.01], [root + 0.0553],
                 [root + 0.0783], [root + 0.01], [root + 0.01]],
                logits_updates + cached_values)
Ejemplo n.º 8
0
  def testPostPruningChangesNothing(self):
    """Test growing an ensemble with post-pruning with all gains >0."""
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as session:
      # Create empty ensemble.
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle

      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare inputs.
      # Both candidate splits have positive gain; the first feature's split
      # (gain 7.62) beats the second's (gain 0.63).
      feature_ids = [3, 4]

      feature1_nodes = np.array([0], dtype=np.int32)
      feature1_gains = np.array([7.62], dtype=np.float32)
      feature1_thresholds = np.array([52], dtype=np.int32)
      feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)

      feature2_nodes = np.array([0], dtype=np.int32)
      feature2_gains = np.array([0.63], dtype=np.float32)
      feature2_thresholds = np.array([23], dtype=np.int32)
      feature2_left_node_contribs = np.array([[-0.6]], dtype=np.float32)
      feature2_right_node_contribs = np.array([[0.24]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=1,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes, feature2_nodes],
          gains=[feature1_gains, feature2_gains],
          thresholds=[feature1_thresholds, feature2_thresholds],
          left_node_contribs=[
              feature1_left_node_contribs, feature2_left_node_contribs
          ],
          right_node_contribs=[
              feature1_right_node_contribs, feature2_right_node_contribs
          ])

      session.run(grow_op)

      # Expect the split from the first feature to be chosen.
      # Pruning got triggered but changed nothing. The expected proto also
      # shows tree 0 finalized (max_depth=1) and an empty second tree added.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 52
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
          nodes {
            leaf {
              scalar: 7.143
            }
          }
        }
        trees {
          nodes {
            leaf {
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 9
0
  def testMetadataWhenCantSplitDuePrePruning(self):
    """Test metadata is updated correctly when no split due to prepruning."""
    # cached_session() replaces the deprecated test_session().
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

      # Create existing ensemble with one root split
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare feature inputs.
      feature_ids = [0, 1, 0]

      # All the gains are negative.
      feature1_nodes = np.array([1], dtype=np.int32)
      feature1_gains = np.array([-1.4], dtype=np.float32)
      feature1_thresholds = np.array([21], dtype=np.int32)
      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)

      feature2_nodes = np.array([1, 2], dtype=np.int32)
      feature2_gains = np.array([-0.63, -2.7], dtype=np.float32)
      feature2_thresholds = np.array([23, 7], dtype=np.int32)
      feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
      feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)

      feature3_nodes = np.array([2], dtype=np.int32)
      feature3_gains = np.array([-2.8], dtype=np.float32)
      feature3_thresholds = np.array([3], dtype=np.int32)
      feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
      feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
          max_depth=3,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
          gains=[feature1_gains, feature2_gains, feature3_gains],
          thresholds=[
              feature1_thresholds, feature2_thresholds, feature3_thresholds
          ],
          left_node_contribs=[
              feature1_left_node_contribs, feature2_left_node_contribs,
              feature3_left_node_contribs
          ],
          right_node_contribs=[
              feature1_right_node_contribs, feature2_right_node_contribs,
              feature3_right_node_contribs
          ])
      session.run(grow_op)

      # Expect that no new split was created because all the gains were negative
      # Global metadata should be updated, tree metadata should not be updated.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      # Parse into a new name instead of shadowing the `tree_ensemble` op
      # wrapper created above.
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 10
0
    def testCachedPredictionFromTheSameTree(self):
        """Tests that prediction based on previous node in the tree works."""
        # cached_session() replaces the deprecated test_session().
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 7.14
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
          nodes {
            leaf {
              scalar: -5.875
            }
          }
          nodes {
            leaf {
              scalar: -2.075
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

            # Create existing ensemble: one finalized tree, two layers deep.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples, one was cached in node 1 first, another in node 0.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # We have two features: 0 and 1.
            feature_0_values = [67, 5]
            feature_1_values = [9, 17]

            # Run training-time prediction starting from the cached nodes.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # We are still in the same tree.
            self.assertAllClose([0, 0], new_tree_ids)
            # When using the full tree, the first example will end up in node 4,
            # the second in node 5.
            self.assertAllClose([4, 5], new_node_ids)
            # Full predictions would be 8.79 and -5.875. The cached values are
            # the original leaves 7.14 (node 1) and -2 (node 0), so the raw
            # updates are 1.65 and -3.875, which are then scaled by the tree
            # weight 0.1.
            self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logits_updates)
Ejemplo n.º 11
0
  def testPrePruning(self):
    """Test growing an existing ensemble with pre-pruning.

    Starts from a non-finalized tree with a single root split and offers
    candidate splits for both of its leaves. Under PRE_PRUNING the candidates
    for node 1 all have negative gain, so node 1 stays a leaf, while node 2 is
    split using its best (positive-gain) candidate.
    """
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

      # Create existing ensemble with one root split.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare feature inputs.
      # For node 1, the best split is on feature 2 (gain -0.63), but the gain
      # is negative so node 1 will not be split.
      # For node 2, the best split is on feature 3 (gain 2.8), which is
      # positive, so node 2 will be split.
      # feature_ids[i] is the feature id written into the tree for a split
      # taken from the i-th candidate list, so features 1 and 3 both map to 0.
      feature_ids = [0, 1, 0]

      feature1_nodes = np.array([1], dtype=np.int32)
      feature1_gains = np.array([-1.4], dtype=np.float32)
      feature1_thresholds = np.array([21], dtype=np.int32)
      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)

      feature2_nodes = np.array([1, 2], dtype=np.int32)
      feature2_gains = np.array([-0.63, 2.7], dtype=np.float32)
      feature2_thresholds = np.array([23, 7], dtype=np.int32)
      feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
      feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)

      feature3_nodes = np.array([2], dtype=np.int32)
      feature3_gains = np.array([2.8], dtype=np.float32)
      feature3_thresholds = np.array([3], dtype=np.int32)
      feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
      feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
          max_depth=3,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
          gains=[feature1_gains, feature2_gains, feature3_gains],
          thresholds=[
              feature1_thresholds, feature2_thresholds, feature3_thresholds
          ],
          left_node_contribs=[
              feature1_left_node_contribs, feature2_left_node_contribs,
              feature3_left_node_contribs
          ],
          right_node_contribs=[
              feature1_right_node_contribs, feature2_right_node_contribs,
              feature3_right_node_contribs
          ])
      session.run(grow_op)

      # Expect node 1 to remain a leaf and node 2 to be split using the
      # candidate from feature 3 (recorded as feature id 0). Each new leaf is
      # the original node-2 leaf plus learning_rate * contrib:
      # -4.375 + 0.1 * -0.75 = -4.45 and -4.375 + 0.1 * 1.93 = -4.182.
      # The grown tree should not be finalized as max tree depth is 3 and
      # it's only grown 2 layers.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 3
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 2.8
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: -4.45
            }
          }
          nodes {
            leaf {
              scalar: -4.182
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: false
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 12
0
  def testMetadataWhenCantSplitDueToEmptySplits(self):
    """Test that the metadata is updated even though we can't split.

    No split candidates are supplied, so the tree itself must come back
    unchanged, while the ensemble-level growing_metadata.num_layers_attempted
    is still expected to advance (from 1 to 2).
    """
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 0.714
            }
          }
          nodes {
            leaf {
              scalar: -0.4375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

      # Create existing ensemble with one root split.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare feature inputs: none. No split candidates are provided for
      # any node, so every per-feature list passed below is empty.

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
          max_depth=2,
          # No splits are available.
          feature_ids=[],
          node_ids=[],
          gains=[],
          thresholds=[],
          left_node_contribs=[],
          right_node_contribs=[])
      session.run(grow_op)

      # Expect no new splits created, but attempted (global) stats updated.
      # Metadata for this tree should not be updated (we didn't succeed in
      # building a layer).
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 0.714
            }
          }
          nodes {
            leaf {
              scalar: -0.4375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 13
0
  def testGrowExistingEnsembleTreeFinalized(self):
    """Test growing an existing ensemble whose first tree is finalized.

    The ensemble holds a finalized tree with one root split followed by an
    ungrown, single-leaf placeholder tree. Growing must leave the finalized
    tree untouched and split the root of the placeholder tree instead.
    """
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.0
            }
          }
        }
        tree_weights: 0.15
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

      # Create existing ensemble: a finalized tree with one root split plus an
      # empty placeholder tree (no layers grown yet).
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare feature inputs: a single candidate for the root (node 0) of the
      # placeholder tree. Its gain is negative, but NO_PRUNING is used so the
      # split is still expected to be applied.
      feature_ids = [75]

      feature1_nodes = np.array([0], dtype=np.int32)
      feature1_gains = np.array([-1.4], dtype=np.float32)
      feature1_thresholds = np.array([21], dtype=np.int32)
      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
          learning_rate=0.1,
          max_depth=2,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes],
          gains=[feature1_gains],
          thresholds=[feature1_thresholds],
          left_node_contribs=[feature1_left_node_contribs],
          right_node_contribs=[feature1_right_node_contribs])
      session.run(grow_op)

      # Expect the placeholder tree to get a root split on feature 75, with
      # leaves 0.0 + 0.1 * -6.0 = -0.6 and 0.0 + 0.1 * 1.65 = 0.165.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
       trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 7.14
            }
          }
          nodes {
            leaf {
              scalar: -4.375
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 75
              threshold: 21
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -1.4
            }
          }
          nodes {
            leaf {
              scalar: -0.6
            }
          }
          nodes {
            leaf {
              scalar: 0.165
            }
          }
        }
        tree_weights: 0.15
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 2
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 14
0
  def testGrowWithEmptyEnsemble(self):
    """Test growing an empty ensemble.

    Three root candidates are offered; the one with the highest gain wins.
    With max_depth=1 the new tree is finalized immediately and an empty dummy
    tree is appended after it.
    """
    with self.cached_session() as session:
      # Create empty ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # feature_ids[i] is the feature id recorded in the tree for a split taken
      # from the i-th candidate list.
      feature_ids = [0, 2, 6]

      # Prepare feature inputs.
      # Features 1 and 3 have close gains (7.62 vs 7.65) but different splits;
      # feature 3 has the strictly higher gain and should be chosen.
      feature1_nodes = np.array([0], dtype=np.int32)
      feature1_gains = np.array([7.62], dtype=np.float32)
      feature1_thresholds = np.array([52], dtype=np.int32)
      feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)

      feature2_nodes = np.array([0], dtype=np.int32)
      feature2_gains = np.array([0.63], dtype=np.float32)
      feature2_thresholds = np.array([23], dtype=np.int32)
      feature2_left_node_contribs = np.array([[-0.6]], dtype=np.float32)
      feature2_right_node_contribs = np.array([[0.24]], dtype=np.float32)

      # Feature split with the highest gain.
      feature3_nodes = np.array([0], dtype=np.int32)
      feature3_gains = np.array([7.65], dtype=np.float32)
      feature3_thresholds = np.array([7], dtype=np.int32)
      feature3_left_node_contribs = np.array([[-4.89]], dtype=np.float32)
      feature3_right_node_contribs = np.array([[5.3]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
          # Tree will be finalized now, since we will reach depth 1.
          max_depth=1,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
          gains=[feature1_gains, feature2_gains, feature3_gains],
          thresholds=[
              feature1_thresholds, feature2_thresholds, feature3_thresholds
          ],
          left_node_contribs=[
              feature1_left_node_contribs, feature2_left_node_contribs,
              feature3_left_node_contribs
          ],
          right_node_contribs=[
              feature1_right_node_contribs, feature2_right_node_contribs,
              feature3_right_node_contribs
          ])
      session.run(grow_op)

      new_stamp, serialized = session.run(tree_ensemble.serialize())

      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      # The split comes from feature 3 (id 6) with leaves 0.1 * -4.89 = -0.489
      # and 0.1 * 5.3 = 0.53. Since the tree is finalized, a new dummy tree is
      # added.
      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 6
              threshold: 7
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.65
            }
          }
          nodes {
            leaf {
              scalar: -0.489
            }
          }
          nodes {
            leaf {
              scalar: 0.53
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.0
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 15
0
  def testGrowExistingEnsembleTreeNotFinalized(self):
    """Test growing an existing ensemble with the last tree not finalized.

    The single existing tree has one root split. Each of its two leaves gets
    the best-gain split among the offered candidates. Reaching max_depth=2
    finalizes the tree and appends an empty dummy tree.
    """
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 0.714
            }
          }
          nodes {
            leaf {
              scalar: -0.4375
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

      # Create existing ensemble with one root split.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare feature inputs.
      # feature 1 only has a candidate for node 1, feature 2 has candidates
      # for both nodes and feature 3 only has a candidate for node 2.
      # Node 1: feature 1 (gain 1.4) beats feature 2 (gain 0.63).
      # Node 2: feature 2 (gain 2.7) beats feature 3 (gain 1.7).
      # feature_ids[i] is the feature id recorded in the tree for a split taken
      # from the i-th candidate list.
      feature_ids = [0, 1, 0]

      feature1_nodes = np.array([1], dtype=np.int32)
      feature1_gains = np.array([1.4], dtype=np.float32)
      feature1_thresholds = np.array([21], dtype=np.int32)
      feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)

      feature2_nodes = np.array([1, 2], dtype=np.int32)
      feature2_gains = np.array([0.63, 2.7], dtype=np.float32)
      feature2_thresholds = np.array([23, 7], dtype=np.int32)
      feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
      feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)

      feature3_nodes = np.array([2], dtype=np.int32)
      feature3_gains = np.array([1.7], dtype=np.float32)
      feature3_thresholds = np.array([3], dtype=np.int32)
      feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
      feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=0.1,
          pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
          # tree is going to be finalized now, since we reach depth 2.
          max_depth=2,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
          gains=[feature1_gains, feature2_gains, feature3_gains],
          thresholds=[
              feature1_thresholds, feature2_thresholds, feature3_thresholds
          ],
          left_node_contribs=[
              feature1_left_node_contribs, feature2_left_node_contribs,
              feature3_left_node_contribs
          ],
          right_node_contribs=[
              feature1_right_node_contribs, feature2_right_node_contribs,
              feature3_right_node_contribs
          ])
      session.run(grow_op)

      # Expect the split for node 1 to be chosen from feature 1 and
      # the split for node 2 to be chosen from feature 2. New leaves are the
      # original leaf plus learning_rate * contrib, e.g.
      # 0.714 + 0.1 * -6.0 = 0.114 and -0.4375 + 0.1 * -1.5 = -0.5875.
      # The grown tree should be finalized as max tree depth is 2 and we have
      # grown 2 layers.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 4
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            bucketized_split {
              threshold: 21
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 0.714
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -0.4375
              }
            }
          }
          nodes {
            leaf {
              scalar: 0.114
            }
          }
          nodes {
            leaf {
              scalar: 0.879
            }
          }
          nodes {
            leaf {
              scalar: -0.5875
            }
          }
          nodes {
            leaf {
              scalar: -0.2075
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.0
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 16
0
 def testCreateWithProto(self):
   """Verifies get_states() on an ensemble deserialized from a text proto.

   The proto holds two trees (only the first finalized) plus growing
   metadata, and the ensemble is created with stamp_token=7. Each state
   tensor returned by get_states() is evaluated and compared to those inputs.
   """
   with self.cached_session():
     source_proto = boosted_trees_pb2.TreeEnsemble()
     text_format.Merge(
         """
     trees {
       nodes {
         bucketized_split {
           feature_id: 4
           left_id: 1
           right_id: 2
         }
         metadata {
           gain: 7.62
         }
       }
       nodes {
         bucketized_split {
           threshold: 21
           left_id: 3
           right_id: 4
         }
         metadata {
           gain: 1.4
           original_leaf {
             scalar: 7.14
           }
         }
       }
       nodes {
         bucketized_split {
           feature_id: 1
           threshold: 7
           left_id: 5
           right_id: 6
         }
         metadata {
           gain: 2.7
           original_leaf {
             scalar: -4.375
           }
         }
       }
       nodes {
         leaf {
           scalar: 6.54
         }
       }
       nodes {
         leaf {
           scalar: 7.305
         }
       }
       nodes {
         leaf {
           scalar: -4.525
         }
       }
       nodes {
         leaf {
           scalar: -4.145
         }
       }
     }
     trees {
       nodes {
         bucketized_split {
           feature_id: 75
           threshold: 21
           left_id: 1
           right_id: 2
         }
         metadata {
           gain: -1.4
         }
       }
       nodes {
         leaf {
           scalar: -0.6
         }
       }
       nodes {
         leaf {
           scalar: 0.165
         }
       }
     }
     tree_weights: 0.15
     tree_weights: 1.0
     tree_metadata {
       num_layers_grown: 2
       is_finalized: true
     }
     tree_metadata {
       num_layers_grown: 1
       is_finalized: false
     }
     growing_metadata {
       num_trees_attempted: 2
       num_layers_attempted: 6
       last_layer_node_start: 16
       last_layer_node_end: 19
     }
   """, source_proto)
     ensemble = boosted_trees_ops.TreeEnsemble(
         'ensemble',
         stamp_token=7,
         serialized_proto=source_proto.SerializeToString())
     resources.initialize_resources(resources.shared_resources()).run()

     stamp, tree_count, finalized_count, attempted_layers, node_range = (
         ensemble.get_states())
     # Scalar states paired with the value the proto/constructor implies,
     # checked in the same order as get_states() returns them.
     scalar_expectations = (
         (7, stamp),
         (2, tree_count),
         (1, finalized_count),
         (6, attempted_layers),
     )
     for expected_value, state_tensor in scalar_expectations:
       self.assertEqual(expected_value, self.evaluate(state_tensor))
     # [last_layer_node_start, last_layer_node_end] from growing_metadata.
     self.assertAllEqual([16, 19], self.evaluate(node_range))
Ejemplo n.º 17
0
  def testPostPruningOfSomeNodes(self):
    """Test growing an ensemble with post-pruning.

    Grows three layers under POST_PRUNING with max_depth=3. Negative-gain
    splits are kept while the tree is still growing. Once the third layer
    finalizes the tree, the splits of node 1 (gain -0.2) and node 3 (gain
    -0.45) are pruned away. post_pruned_nodes_meta then records, for every
    original node id, its new node id and the logit change (new leaf value
    minus old leaf value).
    """
    with self.cached_session() as session:
      # Create empty ensemble.
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle

      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare inputs.
      # Second feature has a larger (but still negative) gain.
      feature_ids = [0, 1]

      feature1_nodes = np.array([0], dtype=np.int32)
      feature1_gains = np.array([-1.3], dtype=np.float32)
      feature1_thresholds = np.array([7], dtype=np.int32)
      feature1_left_node_contribs = np.array([[0.013]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[0.0143]], dtype=np.float32)

      feature2_nodes = np.array([0], dtype=np.int32)
      feature2_gains = np.array([-0.2], dtype=np.float32)
      feature2_thresholds = np.array([33], dtype=np.int32)
      feature2_left_node_contribs = np.array([[0.01]], dtype=np.float32)
      feature2_right_node_contribs = np.array([[0.0143]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=3,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes, feature2_nodes],
          gains=[feature1_gains, feature2_gains],
          thresholds=[feature1_thresholds, feature2_thresholds],
          left_node_contribs=[
              feature1_left_node_contribs, feature2_left_node_contribs
          ],
          right_node_contribs=[
              feature1_right_node_contribs, feature2_right_node_contribs
          ])

      session.run(grow_op)

      # Expect the split from the second feature to be chosen despite the
      # negative gain.
      # No pruning happened just yet.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            leaf {
              scalar: 0.0143
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)

      # Prepare the second layer.
      # Note that node 1 gain is negative and node 2 gain is positive.
      feature_ids = [3]
      feature1_nodes = np.array([1, 2], dtype=np.int32)
      feature1_gains = np.array([-0.2, 0.5], dtype=np.float32)
      feature1_thresholds = np.array([7, 5], dtype=np.int32)
      feature1_left_node_contribs = np.array(
          [[0.07], [0.041]], dtype=np.float32)
      feature1_right_node_contribs = np.array(
          [[0.083], [0.064]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=3,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes],
          gains=[feature1_gains],
          thresholds=[feature1_thresholds],
          left_node_contribs=[feature1_left_node_contribs],
          right_node_contribs=[feature1_right_node_contribs])

      session.run(grow_op)

      # After adding this layer, the tree will not be finalized. New leaves
      # are the parent leaf plus the contrib (learning_rate is 1.0), e.g.
      # 0.01 + 0.07 = 0.08 and 0.0143 + 0.041 = 0.0553.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)
      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id:1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: -0.2
              original_leaf {
                scalar: 0.01
               }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 5
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.08
            }
          }
          nodes {
            leaf {
              scalar: 0.093
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
                scalar: 0.0783
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 2
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
       """
      self.assertEqual(new_stamp, 2)

      self.assertProtoEquals(expected_result, res_ensemble)
      # Now split leaf 3, again with negative gain. After this layer the tree
      # will be finalized and post-pruning happens: nodes 3, 4, 7 and 8 will
      # be pruned out, and node 1 becomes a leaf again.

      # Prepare the third layer.
      feature_ids = [92]
      feature1_nodes = np.array([3], dtype=np.int32)
      feature1_gains = np.array([-0.45], dtype=np.float32)
      feature1_thresholds = np.array([11], dtype=np.int32)
      feature1_left_node_contribs = np.array([[0.15]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[0.5]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=3,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes],
          gains=[feature1_gains],
          thresholds=[feature1_thresholds],
          left_node_contribs=[feature1_left_node_contribs],
          right_node_contribs=[feature1_right_node_contribs])

      session.run(grow_op)
      # After adding this layer, the tree will be finalized.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)
      # Note that nodes 3, 4, 7 and 8 got deleted, so their metadata maps
      # them to their parent node 1, with the respective change in logits
      # (e.g. node 3: 0.01 - 0.08 = -0.07; node 7: 0.01 - 0.23 = -0.22).
      # Surviving nodes 5 and 6 are renumbered to 3 and 4.
      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id:1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.2
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            bucketized_split {
              feature_id: 3
              threshold: 5
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 0.5
              original_leaf {
                scalar: 0.0143
               }
            }
          }
          nodes {
            leaf {
              scalar: 0.0553
            }
          }
          nodes {
            leaf {
              scalar: 0.0783
            }
          }
        }
        trees {
          nodes {
            leaf {
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 3
        }
       """
      self.assertEqual(new_stamp, 3)
      self.assertProtoEquals(expected_result, res_ensemble)
Ejemplo n.º 18
0
    def testCachedPredictionIsCurrent(self):
        """Tests that prediction based on previous node in the tree works.

        Both examples are cached at leaves (nodes 1 and 2) of the single,
        finalized tree. Because the ensemble has not changed since the cache
        was written, training_predict must return the cached tree/node ids
        unchanged and a zero logit update for every example.
        """
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

            # Create existing ensemble with one root split
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples: one was cached in node 1, the other in node 2.
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 2]

            # We have two features: 0 and 1. Values don't matter because the
            # trees didn't change since the predictions were cached.
            feature_0_values = [67, 5]
            feature_1_values = [9, 17]

            # Run training prediction starting from the cached positions.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # Nothing changed: same positions and no additional logit.
            self.assertAllClose(cached_tree_ids, new_tree_ids)
            self.assertAllClose(cached_node_ids, new_node_ids)
            self.assertAllClose([[0], [0]], logits_updates)
Ejemplo n.º 19
0
    def testCachedPredictionTheWholeTreeWasPruned(self):
        """Tests cached predictions when the cached tree was pruned to its root.

        The examples were cached at nodes 1 and 2, which post-pruning removed
        and remapped to node 0 (see post_pruned_nodes_meta). With a tree weight
        of 1.0, each example's logit update is expected to equal the
        logit_change recorded for its cached node id.
        """
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            leaf {
              scalar: 0.00
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: -6.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 5.0
          }
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

            # Create existing ensemble.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            cached_tree_ids = [0, 0]
            # The predictions were cached in 1 and 2, both were pruned to the root.
            cached_node_ids = [1, 2]

            # We have two features: 0 and 1. These are not used anywhere,
            # because the only tree is a single leaf.
            feature_0_values = [12, 17]
            feature_1_values = [12, 12]

            # Run training prediction starting from the cached positions.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # Both examples stay in the only (last) tree, now at its root.
            self.assertAllClose([0, 0], new_tree_ids)
            self.assertAllClose([0, 0], new_node_ids)

            # Updates match the logit_change of pruned nodes 1 and 2.
            self.assertAllClose([[-6.0], [5.0]], logits_updates)
Ejemplo n.º 20
0
    def testCachedPredictionFromPreviousTree(self):
        """Tests the predictions work when we have cache from previous trees.

        The ensemble has three trees; the first two are finalized. The first
        example is cached at a leaf of tree 0, so only trees 1 and 2 add to its
        logit update. The second example is cached at the root of tree 0, so
        all three trees contribute.
        """
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7
            }
          }
          nodes {
            leaf {
              scalar: 5
            }
          }
          nodes {
            leaf {
              scalar: 6
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: false
        }
        tree_weights: 0.1
        tree_weights: 0.1
        tree_weights: 0.1
      """, tree_ensemble_config)

            # Create existing ensemble with three trees.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples, both cached in tree 0: the first in node 1 (a
            # leaf), the second in node 0 (the root).
            cached_tree_ids = [0, 0]
            cached_node_ids = [1, 0]

            # We have two features: 0 and 1.
            feature_0_values = [36, 32]
            feature_1_values = [11, 27]

            # Run training prediction starting from the cached positions.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)
            # Example 1 will get to node 3 in tree 1 and node 2 of tree 2.
            # Example 2 will get to node 1 in tree 0, node 2 in tree 1 and
            # node 1 of tree 2.

            # We are in the last tree.
            self.assertAllClose([2, 2], new_tree_ids)
            # Final positions are the leaves reached in tree 2.
            self.assertAllClose([2, 1], new_node_ids)
            # Example 1: tree 0 already accounted for by the cache,
            #            tree 1: 5.0, tree 2: 5.0 =>
            #            change = 0.1*(5.0+5.0) = 1.0
            # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7.0 =>
            #            change = 0.1*(1.14+7.0-7.0) = 0.114
            self.assertAllClose([[1], [0.114]], logits_updates)
Ejemplo n.º 21
0
    def testPredictionMultipleTreeMultiClass(self):
        """Checks two-dimensional logits summed over three weighted trees.

        Every leaf holds one value per logit dimension. The expected logits
        are the tree-weighted sums of the leaves each example reaches.
        """
        ensemble_text = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              vector: {
                value: 0.51
              }
              vector: {
                value: 1.14
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: 1.29
              }
              vector: {
                value: 8.79
              }
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              vector: {
                value: -4.33
              }
              vector: {
                value: 7.0
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: 0.2
              }
              vector: {
                value: 5.0
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: -4.1
              }
              vector: {
                value: 6.0
              }
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              vector: {
                value: 2.0
              }
              vector: {
                value: -7.0
              }
            }
          }
          nodes {
            leaf {
              vector: {
                value: 6.3
              }
              vector: {
                value: 5.0
              }
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
      """
        with self.cached_session() as sess:
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(ensemble_text, ensemble_proto)

            # Build the three-tree ensemble resource from the proto above.
            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_proto.SerializeToString())
            handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Rows are feature 0 and feature 1; columns are the two examples.
            features = [[36, 32], [11, 27]]

            # Example 1 leaves: (0.51, 1.14), (0.2, 5.0), (6.3, 5.0)
            #   -> (0.1*0.51 + 0.2*0.2 + 1*6.3, 0.1*1.14 + 0.2*5.0 + 1*5.0)
            # Example 2 leaves: (0.51, 1.14), (-4.33, 7.0), (2.0, -7.0)
            #   -> (0.1*0.51 + 0.2*-4.33 + 1*2.0, 0.1*1.14 + 0.2*7.0 + 1*-7.0)
            expected = [[6.391, 6.114], [1.185, -5.486]]

            logits = sess.run(
                boosted_trees_ops.predict(
                    handle, bucketized_features=features, logits_dimension=2))
            self.assertAllClose(expected, logits)
Ejemplo n.º 22
0
    def testNoCachedPredictionButTreeExists(self):
        """Tests that predictions are updated once trees are added.

        Neither example has a cached position beyond the root of tree 0, so
        both are routed through the single tree and receive the full,
        tree-weighted leaf value as their logit update.
        """
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

            # Create existing ensemble with one root split
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # Two examples, none were cached before.
            cached_tree_ids = [0, 0]
            cached_node_ids = [0, 0]

            feature_0_values = [67, 5]

            # Run training prediction from the root of tree 0.
            predict_op = boosted_trees_ops.training_predict(
                tree_ensemble_handle,
                cached_tree_ids=cached_tree_ids,
                cached_node_ids=cached_node_ids,
                bucketized_features=[feature_0_values],
                logits_dimension=1)

            logits_updates, new_tree_ids, new_node_ids = session.run(
                predict_op)

            # Both examples stay in the first (and only) tree: 67 lands in
            # node 2 (8.79) and 5 lands in node 1 (1.14).
            self.assertAllClose([0, 0], new_tree_ids)
            self.assertAllClose([2, 1], new_node_ids)
            self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logits_updates)
Ejemplo n.º 23
0
  def testPostPruningOfAllNodes(self):
    """Test growing an ensemble with post-pruning, with all nodes are pruned.

    Two layers are grown, and every split has a negative gain. Once the tree
    reaches max_depth and is finalized, post-pruning collapses it to a single
    empty root leaf. The pruned-node metadata records how much logit each
    pruned node had contributed.
    """
    with self.cached_session() as session:
      # Create empty ensemble.
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle

      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare inputs. All have negative gains.
      feature_ids = [0, 1]

      feature1_nodes = np.array([0], dtype=np.int32)
      feature1_gains = np.array([-1.3], dtype=np.float32)
      feature1_thresholds = np.array([7], dtype=np.int32)
      feature1_left_node_contribs = np.array([[0.013]], dtype=np.float32)
      feature1_right_node_contribs = np.array([[0.0143]], dtype=np.float32)

      feature2_nodes = np.array([0], dtype=np.int32)
      feature2_gains = np.array([-0.62], dtype=np.float32)
      feature2_thresholds = np.array([33], dtype=np.int32)
      feature2_left_node_contribs = np.array([[0.01]], dtype=np.float32)
      feature2_right_node_contribs = np.array([[0.0143]], dtype=np.float32)

      # Grow tree ensemble.
      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=2,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes, feature2_nodes],
          gains=[feature1_gains, feature2_gains],
          thresholds=[feature1_thresholds, feature2_thresholds],
          left_node_contribs=[
              feature1_left_node_contribs, feature2_left_node_contribs
          ],
          right_node_contribs=[
              feature1_right_node_contribs, feature2_right_node_contribs
          ])

      session.run(grow_op)

      # Expect the split from the second candidate (feature_id 1, gain -0.62)
      # to be chosen despite the negative gain, since it beats -1.3.
      # The grown tree should not be finalized as max tree depth is 2 so no
      # pruning occurs.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      expected_result = """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 33
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: -0.62
            }
          }
          nodes {
            leaf {
              scalar: 0.01
            }
          }
          nodes {
            leaf {
              scalar: 0.0143
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """
      self.assertEqual(new_stamp, 1)
      self.assertProtoEquals(expected_result, res_ensemble)

      # Prepare inputs.
      # All have negative gain.
      feature_ids = [3]
      feature1_nodes = np.array([1, 2], dtype=np.int32)
      feature1_gains = np.array([-0.2, -0.5], dtype=np.float32)
      feature1_thresholds = np.array([77, 79], dtype=np.int32)
      feature1_left_node_contribs = np.array([[0.023], [0.3]], dtype=np.float32)
      feature1_right_node_contribs = np.array(
          [[0.012343], [24]], dtype=np.float32)

      grow_op = boosted_trees_ops.update_ensemble(
          tree_ensemble_handle,
          learning_rate=1.0,
          pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
          max_depth=2,
          feature_ids=feature_ids,
          node_ids=[feature1_nodes],
          gains=[feature1_gains],
          thresholds=[feature1_thresholds],
          left_node_contribs=[feature1_left_node_contribs],
          right_node_contribs=[feature1_right_node_contribs])

      session.run(grow_op)

      # Expect the split from feature 1 to be chosen despite the negative gain.
      # The grown tree should be finalized. Since all nodes have negative gain,
      # the whole tree is pruned.
      new_stamp, serialized = session.run(tree_ensemble.serialize())
      res_ensemble = boosted_trees_pb2.TreeEnsemble()
      res_ensemble.ParseFromString(serialized)

      # The finalized tree is collapsed to a single root leaf with no value.
      # Every pruned node maps to new_node_id 0, and its logit_change is the
      # negated value that node used to contribute. A fresh empty tree is
      # appended for the next round of growing.
      self.assertEqual(new_stamp, 2)
      self.assertProtoEquals("""
      trees {
        nodes {
          leaf {
          }
        }
      }
      trees {
        nodes {
          leaf {
          }
        }
      }
      tree_weights: 1.0
      tree_weights: 1.0
      tree_metadata{
        num_layers_grown: 2
        is_finalized: true
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: 0.0
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.01
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.0143
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.033
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.022343
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -0.3143
        }
        post_pruned_nodes_meta {
          new_node_id: 0
          logit_change: -24.0143
        }
      }
      tree_metadata {
      }
      growing_metadata {
        num_trees_attempted: 1
        num_layers_attempted: 2
      }
      """, res_ensemble)
Ejemplo n.º 24
0
    def testPredictionMultipleTree(self):
        """Tests the predictions work when we have multiple trees.

        Each example's logit is the tree-weighted sum of the leaves it reaches
        in the three trees. The prediction is run twice, with two separately
        built ops, to check that repeated prediction gives the same result.
        """
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
      """, tree_ensemble_config)

            # Create existing ensemble with three trees.
            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            feature_0_values = [36, 32]
            feature_1_values = [11, 27]

            # Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 =>
            #            logit = 0.1*1.14+0.2*5.0+1*5.0 = 6.114
            # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7.0 =>
            #            logit = 0.1*1.14+0.2*7.0-1*7.0 = -5.486
            expected_logits = [[6.114], [-5.486]]

            # First prediction run.
            predict_op = boosted_trees_ops.predict(
                tree_ensemble_handle,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits = session.run(predict_op)
            self.assertAllClose(expected_logits, logits)

            # Build and run an identical predict op again: prediction must not
            # modify the ensemble, so the result is the same.
            predict_op = boosted_trees_ops.predict(
                tree_ensemble_handle,
                bucketized_features=[feature_0_values, feature_1_values],
                logits_dimension=1)

            logits = session.run(predict_op)
            self.assertAllClose(expected_logits, logits)
Ejemplo n.º 25
0
  def testCategoricalSplits(self):
    """Checks training prediction through a tree of categorical splits.

    Nothing is cached yet, so every example is routed from the root of tree 0
    and receives the full value of the leaf it lands in.
    """
    ensemble_text = """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
        }
      """
    with self.cached_session() as sess:
      ensemble_proto = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(ensemble_text, ensemble_proto)

      # Depth-2 tree: categorical split on feature 1 at the root, then a
      # categorical split on feature 0 under the root's left child.
      ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=ensemble_proto.SerializeToString())
      handle = ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      bucketized = [
          [13, 1, 3],  # feature 0
          [2, 2, 1],  # feature 1
      ]
      # No previously cached positions: everyone starts at tree 0, node 0.
      start_trees = [0, 0, 0]
      start_nodes = [0, 0, 0]

      updates, tree_ids, node_ids = sess.run(
          boosted_trees_ops.training_predict(
              handle,
              cached_tree_ids=start_trees,
              cached_node_ids=start_nodes,
              bucketized_features=bucketized,
              logits_dimension=1))

      # (f0, f1): (13, 2) -> node 3 (5.0); (1, 2) -> node 4 (6.0);
      #           (3, 1) -> node 2 (7.0).
      self.assertAllClose([0, 0, 0], tree_ids)
      self.assertAllClose([3, 4, 2], node_ids)
      self.assertAllClose([[5.], [6.], [7.]], updates)