Пример #1
0
    def testContribsForOnlyABiasNode(self):
        """Tests case when, after training, only left with a bias node.

        For example, this could happen if the final ensemble contains one tree
        that got pruned up to the root. Each example is then expected to report
        no feature ids and a logits path holding only the bias.
        """
        # `test_session` is deprecated; `cached_session` is its supported
        # replacement.
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
      """, tree_ensemble_config)

            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # All features are unused: the only tree is a single leaf, so no
            # split ever reads them.
            feature_0_values = [36, 32]
            feature_1_values = [13, -29]
            feature_2_values = [11, 27]

            # Expected logits are computed by traversing the logit path and
            # subtracting child logits from parent logits.
            bias = 1.72 * 0.1  # Root node of tree_0, scaled by its weight.
            expected_feature_ids = ((), ())
            expected_logits_paths = ((bias, ), (bias, ))

            bucketized_features = [
                feature_0_values, feature_1_values, feature_2_values
            ]

            debug_op = boosted_trees_ops.example_debug_outputs(
                tree_ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            # The op yields one serialized DebugOutput proto per example.
            serialized_examples_debug_outputs = session.run(debug_op)
            feature_ids = []
            logits_paths = []
            for example in serialized_examples_debug_outputs:
                example_debug_outputs = boosted_trees_pb2.DebugOutput()
                example_debug_outputs.ParseFromString(example)
                feature_ids.append(example_debug_outputs.feature_ids)
                logits_paths.append(example_debug_outputs.logits_path)

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)
Пример #2
0
def _parse_debug_proto_string(example_proto_serialized):
    """Parses a serialized DebugOutput proto.

    Args:
      example_proto_serialized: serialized `boosted_trees_pb2.DebugOutput`
        message for one example.

    Returns:
      A `(feature_ids, logits_path)` pair read from the parsed message.
    """
    debug_output = boosted_trees_pb2.DebugOutput()
    debug_output.ParseFromString(example_proto_serialized)
    return debug_output.feature_ids, debug_output.logits_path
Пример #3
0
    def testContribsMultipleTree(self):
        """Checks the per-example debug outputs of a three-tree ensemble."""
        with self.cached_session() as sess:
            ensemble_proto = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 28
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf: {scalar: 2.1}
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_weights: 0.1
        tree_weights: 0.2
        tree_weights: 1.0
        tree_metadata: {
          num_layers_grown: 1}
        tree_metadata: {
          num_layers_grown: 2}
        tree_metadata: {
          num_layers_grown: 1}
      """, ensemble_proto)

            ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=ensemble_proto.SerializeToString())
            ensemble_handle = ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            # One inner list per feature, one entry per example.
            bucketized_features = [
                [36, 32],  # feature 0
                [13, -29],  # feature 1: unused, not in the ensemble above.
                [11, 27],  # feature 2
            ]

            # Expected logits are computed by traversing the logit path and
            # subtracting child logits from parent logits.
            bias = 2.1 * 0.1  # Root node of tree_0.
            expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
            # example_0: (bias,
            #             0.1 * 1.14,
            #             0.2 * 5.5 + .114,
            #             0.2 * 5. + .114,
            #             1.0 * 5.0 + 0.2 * 5. + .114)
            example_0_path = (bias, 0.114, 1.214, 1.114, 6.114)
            # example_1: (bias,
            #             0.1 * 1.14,
            #             0.2 * 7 + .114,
            #             1.0 * -7. + 0.2 * 7 + .114)
            example_1_path = (bias, 0.114, 1.514, -5.486)
            expected_logits_paths = (example_0_path, example_1_path)

            debug_op = boosted_trees_ops.example_debug_outputs(
                ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            # Decode every serialized DebugOutput first, then pull out the two
            # repeated fields that are compared below.
            parsed_outputs = []
            for serialized in sess.run(debug_op):
                debug_output = boosted_trees_pb2.DebugOutput()
                debug_output.ParseFromString(serialized)
                parsed_outputs.append(debug_output)
            feature_ids = [out.feature_ids for out in parsed_outputs]
            logits_paths = [out.logits_path for out in parsed_outputs]

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)
Пример #4
0
    def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
        """Tests case when, after training, first tree contains only a bias node.

        The ensemble holds a single-leaf first tree (weight 1.0) followed by a
        tree with two splits (weight 0.1); the expected feature ids therefore
        come from the second tree only.
        """
        # `test_session` is deprecated; `cached_session` is its supported
        # replacement.
        with self.cached_session() as session:
            tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
            text_format.Merge(
                """
        trees {
          nodes {
            leaf {
              scalar: 1.72
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 2
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
            metadata {
              original_leaf: {scalar: 5.5}
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.
        tree_weights: 0.1
        tree_metadata: {
          num_layers_grown: 0
        }
        tree_metadata: {
          num_layers_grown: 1
        }
      """, tree_ensemble_config)

            tree_ensemble = boosted_trees_ops.TreeEnsemble(
                'ensemble',
                serialized_proto=tree_ensemble_config.SerializeToString())
            tree_ensemble_handle = tree_ensemble.resource_handle
            resources.initialize_resources(resources.shared_resources()).run()

            feature_0_values = [36, 32]
            feature_1_values = [13, -29]  # Unused feature.
            feature_2_values = [11, 27]

            # Expected logits are computed by traversing the logit path and
            # subtracting child logits from parent logits.
            expected_feature_ids = ((2, 0), (2, ))
            # The bias is the root node of tree_0: 1.72 * 1. = 1.72.
            # example_0 :  (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
            # example_1 :  (bias, 0.1 * 7. + bias )
            expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))

            bucketized_features = [
                feature_0_values, feature_1_values, feature_2_values
            ]

            debug_op = boosted_trees_ops.example_debug_outputs(
                tree_ensemble_handle,
                bucketized_features=bucketized_features,
                logits_dimension=1)

            # The op yields one serialized DebugOutput proto per example.
            serialized_examples_debug_outputs = session.run(debug_op)
            feature_ids = []
            logits_paths = []
            for example in serialized_examples_debug_outputs:
                example_debug_outputs = boosted_trees_pb2.DebugOutput()
                example_debug_outputs.ParseFromString(example)
                feature_ids.append(example_debug_outputs.feature_ids)
                logits_paths.append(example_debug_outputs.logits_path)

            self.assertAllClose(feature_ids, expected_feature_ids)
            self.assertAllClose(logits_paths, expected_logits_paths)