def setUp(self):
  """Builds the ensembles shared by the ensemble-optimizer tests.

  `self._tree_ensemble` holds two trees made only of two leaf nodes each,
  both weighted 1.0, with 2 and 3 recorded weight updates respectively.
  `self._ensemble_to_add` holds a single tree (also kept as
  `self._tree_to_add`) that the tests merge into other ensembles.
  """
  super(EnsembleOptimizerOpsTest, self).setUp()
  ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
  self._tree_ensemble = ensemble
  # One entry per tree: the (class id, value) payload of each of its leaves.
  leaf_payloads_per_tree = (
      ((0, 0.4), (1, 0.6)),  # First tree.
      ((0, 1), (1, 0)),  # Second tree.
  )
  for leaf_payloads in leaf_payloads_per_tree:
    new_tree = ensemble.trees.add()
    for class_id, value in leaf_payloads:
      _append_to_leaf(new_tree.nodes.add().leaf, class_id, value)
  ensemble.tree_weights.extend([1.0, 1.0])
  for update_count in (2, 3):
    ensemble.tree_metadata.add().num_tree_weight_updates = update_count

  # Ensemble to be added.
  self._ensemble_to_add = tree_config_pb2.DecisionTreeEnsembleConfig()
  self._tree_to_add = self._ensemble_to_add.trees.add()
  for class_id, value in ((0, 0.3), (1, 0.7)):
    _append_to_leaf(self._tree_to_add.nodes.add().leaf, class_id, value)
def testWithExistingEnsembleAndShrinkage(self):
  """Adds a tree with shrinkage (learning rate) to an existing ensemble.

  The five existing trees must keep their weights and update counts, the new
  tree must be appended with weight `learning_rate`, usage counts must be
  added as-is and gains must be added scaled by `learning_rate`.
  """
  with self.test_session():
    # Add shrinkage config.
    learning_rate = 0.0001
    tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
    # Add 5 trees with weights 1..5, each updated once.
    for i in range(0, 5):
      tree = tree_ensemble.trees.add()
      _append_to_leaf(tree.nodes.add().leaf, 0, -0.4)
      tree_ensemble.tree_weights.append(i + 1)
      meta = tree_ensemble.tree_metadata.add()
      meta.num_tree_weight_updates = 1
    tree_ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0,
        tree_ensemble_config=tree_ensemble.SerializeToString(),
        name="existing")

    # Create non-zero feature importance.
    feature_usage_counts = variables.Variable(
        initial_value=np.array([4, 7], np.int64),
        name="feature_usage_counts",
        trainable=False)
    feature_gains = variables.Variable(
        initial_value=np.array([0.2, 0.8], np.float32),
        name="feature_gains",
        trainable=False)

    resources.initialize_resources(resources.shared_resources()).run()
    # initialize_all_variables() is a deprecated alias of this initializer.
    variables.global_variables_initializer().run()

    output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
    with ops.control_dependencies([
        ensemble_optimizer_ops.add_trees_to_ensemble(
            tree_ensemble_handle,
            self._ensemble_to_add.SerializeToString(),
            feature_usage_counts, [1, 2],
            feature_gains, [0.5, 0.3], [[], []],
            learning_rate=learning_rate)
    ]):
      output_ensemble.ParseFromString(
          model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval())

    # The weights of previous trees stayed the same, new tree (LAST) is added
    # with shrinkage weight.
    self.assertAllClose([1.0, 2.0, 3.0, 4.0, 5.0, learning_rate],
                        output_ensemble.tree_weights)

    # Check that all number of updates are equal to 1 (i.e. no old tree weight
    # got adjusted).
    for i in range(0, 6):
      self.assertEqual(
          1, output_ensemble.tree_metadata[i].num_tree_weight_updates)

    # Ensure feature importance was aggregated correctly.
    self.assertAllEqual([5, 9], feature_usage_counts.eval())
    self.assertArrayNear(
        [0.2 + 0.5 * learning_rate, 0.8 + 0.3 * learning_rate],
        feature_gains.eval(), 1e-6)
def testCreate(self):
  """A new ensemble variable predicts from its config and keeps its stamp."""
  with self.test_session():
    ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
    only_tree = ensemble_config.trees.add()
    _append_to_leaf(only_tree.nodes.add().leaf, 0, -0.4)
    ensemble_config.tree_weights.append(1.0)

    # Prepare learner config.
    learner = learner_pb2.LearnerConfig()
    learner.num_classes = 2

    handle = model_ops.tree_ensemble_variable(
        stamp_token=3,
        tree_ensemble_config=ensemble_config.SerializeToString(),
        name="create_tree")
    resources.initialize_resources(resources.shared_resources()).run()

    predictions, _, _ = prediction_ops.gradient_trees_prediction(
        handle,
        self._seed, [self._dense_float_tensor],
        [self._sparse_float_indices1, self._sparse_float_indices2],
        [self._sparse_float_values1, self._sparse_float_values2],
        [self._sparse_float_shape1, self._sparse_float_shape2],
        [self._sparse_int_indices1],
        [self._sparse_int_values1],
        [self._sparse_int_shape1],
        learner_config=learner.SerializeToString(),
        apply_dropout=False,
        apply_averaging=False,
        center_bias=False,
        reduce_dim=True)
    # The single leaf's -0.4 is predicted for both examples.
    self.assertAllClose(predictions.eval(), [[-0.4], [-0.4]])

    stamp = model_ops.tree_ensemble_stamp_token(handle)
    self.assertEqual(stamp.eval(), 3)
def testTrainFnNonChiefWithCentering(self):
  """Tests the train function running on worker with bias centering.

  A non-chief worker only accumulates statistics, so after several train
  steps the ensemble must still be empty and its stamp unchanged.
  """
  with self.test_session():
    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
    learner_config = learner_pb2.LearnerConfig()
    learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
    learner_config.num_classes = 2
    learner_config.regularization.l1 = 0
    learner_config.regularization.l2 = 0
    learner_config.constraints.max_tree_depth = 1
    learner_config.constraints.min_node_weight = 0
    features = {}
    features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)

    gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
        is_chief=False,
        num_ps_replicas=0,
        center_bias=True,
        ensemble_handle=ensemble_handle,
        examples_per_layer=1,
        learner_config=learner_config,
        features=features)

    predictions = array_ops.constant(
        [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
    partition_ids = array_ops.zeros([4], dtypes.int32)
    ensemble_stamp = variables.Variable(
        initial_value=0,
        name="ensemble_stamp",
        trainable=False,
        dtype=dtypes.int64)

    predictions_dict = {
        "predictions": predictions,
        "predictions_no_dropout": predictions,
        "partition_ids": partition_ids,
        "ensemble_stamp": ensemble_stamp
    }

    labels = array_ops.ones([4, 1], dtypes.float32)
    weights = array_ops.ones([4, 1], dtypes.float32)
    # Create train op.
    train_op = gbdt_model.train(
        loss=math_ops.reduce_mean(
            _squared_loss(labels, weights, predictions)),
        predictions_dict=predictions_dict,
        labels=labels)
    variables.global_variables_initializer().run()
    resources.initialize_resources(resources.shared_resources()).run()

    # Regardless of how many times the train op is run, a non-chief worker
    # can only accumulate stats so the tree ensemble never changes.
    for _ in range(5):
      train_op.run()
    stamp_token, serialized = model_ops.tree_ensemble_serialize(
        ensemble_handle)
    output = tree_config_pb2.DecisionTreeEnsembleConfig()
    output.ParseFromString(serialized.eval())
    self.assertEqual(len(output.trees), 0)
    self.assertEqual(len(output.tree_weights), 0)
    self.assertEqual(stamp_token.eval(), 0)
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
  """Exports a SavedModel, then post-processes the tree ensemble inside it.

  After `base_strategy` writes the SavedModel, the serialized
  DecisionTreeEnsembleConfig is read back out of the exported graph. If
  `convert_fn` is set, it receives that config. In all cases the feature
  importances are written, highest first, as "name, value" lines to
  `<result_dir>/assets.extra/feature_importances`.

  Args:
    estimator: estimator to export; forwarded to `base_strategy.export`.
    export_dir: base export directory; forwarded to `base_strategy.export`.
    checkpoint_path: optional checkpoint; forwarded to `base_strategy.export`.
    eval_result: optional evaluation result; forwarded to
      `base_strategy.export` and to `convert_fn`.

  Returns:
    The directory returned by `base_strategy.export`.
  """
  result_dir = base_strategy.export(
      estimator, export_dir, checkpoint_path, eval_result)
  with ops.Graph().as_default() as graph:
    with tf_session.Session(graph=graph) as sess:
      saved_model_loader.load(sess, [tag_constants.SERVING], result_dir)
      # Note: This is GTFlow internal API and might change.
      serialize_op = graph.get_operation_by_name(
          "ensemble_model/TreeEnsembleSerialize")
      _, serialized_ensemble = sess.run(serialize_op.outputs)
      ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      ensemble_config.ParseFromString(serialized_ensemble)

      # Export the result in the same folder as the saved model.
      if convert_fn:
        convert_fn(ensemble_config, sorted_feature_names,
                   len(dense_floats), len(sparse_float_indices),
                   len(sparse_int_indices), result_dir, eval_result)

      importance_by_feature = _get_feature_importances(
          ensemble_config, sorted_feature_names, len(dense_floats),
          len(sparse_float_indices), len(sparse_int_indices))
      # Most important first; sorted() is stable, so ties keep their order.
      ranked = sorted(importance_by_feature.items(),
                      key=lambda name_and_score: -name_and_score[1])

      extra_assets_dir = os.path.join(compat.as_bytes(result_dir),
                                      compat.as_bytes("assets.extra"))
      gfile.MakeDirs(extra_assets_dir)
      importances_path = os.path.join(compat.as_bytes(extra_assets_dir),
                                      compat.as_bytes("feature_importances"))
      with gfile.GFile(importances_path, "w") as out_file:
        out_file.write(
            "\n".join(["%s, %f" % (name, score) for name, score in ranked]))
  return result_dir
def _assert_checkpoint_and_return_model(self, model_dir, global_step):
  """Checks a checkpoint's global step and returns the ensemble it stores.

  Args:
    model_dir: directory holding the checkpoint to read.
    global_step: expected value of the checkpoint's global step tensor.

  Returns:
    The DecisionTreeEnsembleConfig parsed from the checkpoint tensor
    "ensemble_model:0_config".
  """
  checkpoint = checkpoint_utils.load_checkpoint(model_dir)
  self.assertEqual(global_step,
                   checkpoint.get_tensor(ops.GraphKeys.GLOBAL_STEP))
  model = tree_config_pb2.DecisionTreeEnsembleConfig()
  model.ParseFromString(checkpoint.get_tensor("ensemble_model:0_config"))
  return model
def testWithEmptyEnsembleAndShrinkage(self):
  """Adds a tree with shrinkage (learning rate) to an empty ensemble.

  The single resulting tree must get weight `learning_rate`, usage counts are
  added as-is and gains are added scaled by `learning_rate`.
  """
  with self.test_session():
    # Add shrinkage config.
    learning_rate = 0.0001
    tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
    tree_ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0,
        tree_ensemble_config=tree_ensemble.SerializeToString(),
        name="existing")

    # Create zero feature importance.
    feature_usage_counts = variables.Variable(
        initial_value=np.array([0, 0], np.int64),
        name="feature_usage_counts",
        trainable=False)
    feature_gains = variables.Variable(
        initial_value=np.array([0.0, 0.0], np.float32),
        name="feature_gains",
        trainable=False)

    resources.initialize_resources(resources.shared_resources()).run()
    # initialize_all_variables() is a deprecated alias of this initializer.
    variables.global_variables_initializer().run()

    output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
    with ops.control_dependencies([
        ensemble_optimizer_ops.add_trees_to_ensemble(
            tree_ensemble_handle,
            self._ensemble_to_add.SerializeToString(),
            feature_usage_counts, [1, 2],
            feature_gains, [0.5, 0.3], [[], []],
            learning_rate=learning_rate)
    ]):
      output_ensemble.ParseFromString(
          model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval())

    # Exactly one new tree is added, with shrinkage weight.
    self.assertEqual(1, len(output_ensemble.trees))
    self.assertAllClose([learning_rate], output_ensemble.tree_weights)
    self.assertEqual(
        1, output_ensemble.tree_metadata[0].num_tree_weight_updates)
    self.assertAllEqual([1, 2], feature_usage_counts.eval())
    self.assertArrayNear([0.5 * learning_rate, 0.3 * learning_rate],
                         feature_gains.eval(), 1e-6)
def testPredictFn(self):
  """Tests the predict function on an ensemble holding one bias leaf.

  Every example must get the leaf value 0.25, partition 0 and the
  ensemble's stamp (3).
  """
  with self.test_session() as sess:
    # Create ensemble with one bias node.
    ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
    text_format.Merge(
        """ trees { nodes { leaf { vector { value: 0.25 } } } } tree_weights: 1.0 tree_metadata { num_tree_weight_updates: 1 num_layers_grown: 1 is_finalized: true }""",
        ensemble_config)
    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=3,
        tree_ensemble_config=ensemble_config.SerializeToString(),
        name="tree_ensemble")
    resources.initialize_resources(resources.shared_resources()).run()
    learner_config = learner_pb2.LearnerConfig()
    learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
    learner_config.num_classes = 2
    learner_config.regularization.l1 = 0
    learner_config.regularization.l2 = 0
    learner_config.constraints.max_tree_depth = 1
    learner_config.constraints.min_node_weight = 0
    features = {}
    features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)
    gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
        is_chief=False,
        num_ps_replicas=0,
        center_bias=True,
        ensemble_handle=ensemble_handle,
        examples_per_layer=1,
        learner_config=learner_config,
        features=features)

    # Create predict op.
    mode = model_fn.ModeKeys.EVAL
    predictions_dict = sess.run(gbdt_model.predict(mode))
    self.assertEqual(predictions_dict["ensemble_stamp"], 3)
    self.assertAllClose(predictions_dict["predictions"],
                        [[0.25], [0.25], [0.25], [0.25]])
    self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])
def testUsedHandlers(self):
  """Handler ids stored in growing metadata are reported as used."""
  with self.cached_session():
    ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
    ensemble_config.growing_metadata.used_handler_ids.extend([1, 5])
    token = 3
    handle = model_ops.tree_ensemble_variable(
        stamp_token=token,
        tree_ensemble_config=ensemble_config.SerializeToString(),
        name="create_tree")
    resources.initialize_resources(resources.shared_resources()).run()

    used = model_ops.tree_ensemble_used_handlers(
        handle, token, num_all_handlers=6)
    # One mask entry per handler id 0..5; only ids 1 and 5 are set.
    self.assertAllEqual([0, 1, 0, 0, 0, 1], used.used_handlers_mask.eval())
    self.assertEqual(2, used.num_used_handlers.eval())
def testWithExistingEnsemble(self):
  """Adds one tree to an existing ensemble with learning rate 1.

  The existing trees keep their weights and update counts (2 and 3). The
  added tree is appended with weight 1.0 and one update. Usage counts and
  gains are summed with the values to add.
  """
  with self.test_session():
    # Create existing tree ensemble.
    tree_ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0,
        tree_ensemble_config=self._tree_ensemble.SerializeToString(),
        name="existing")
    # Create non-zero feature importance.
    feature_usage_counts = variables.Variable(
        initial_value=np.array([0, 4, 1], np.int64),
        name="feature_usage_counts",
        trainable=False)
    feature_gains = variables.Variable(
        initial_value=np.array([0.0, 0.3, 0.05], np.float32),
        name="feature_gains",
        trainable=False)

    resources.initialize_resources(resources.shared_resources()).run()
    # initialize_all_variables() is a deprecated alias of this initializer.
    variables.global_variables_initializer().run()
    output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
    with ops.control_dependencies([
        ensemble_optimizer_ops.add_trees_to_ensemble(
            tree_ensemble_handle,
            self._ensemble_to_add.SerializeToString(),
            feature_usage_counts, [1, 2, 0],
            feature_gains, [0.02, 0.1, 0.0], [[], []],
            learning_rate=1)
    ]):
      output_ensemble.ParseFromString(
          model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval())

    # Output.
    self.assertEqual(3, len(output_ensemble.trees))
    self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[2])

    self.assertAllEqual([1.0, 1.0, 1.0], output_ensemble.tree_weights)

    self.assertEqual(
        2, output_ensemble.tree_metadata[0].num_tree_weight_updates)
    self.assertEqual(
        3, output_ensemble.tree_metadata[1].num_tree_weight_updates)
    self.assertEqual(
        1, output_ensemble.tree_metadata[2].num_tree_weight_updates)
    self.assertAllEqual([1, 6, 1], feature_usage_counts.eval())
    self.assertArrayNear([0.02, 0.4, 0.05], feature_gains.eval(), 1e-6)
def testWithEmptyEnsemble(self):
  """Adds a single tree to an empty ensemble with learning rate 1.0."""
  with self.test_session():
    # Create an empty ensemble.
    tree_ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0, tree_ensemble_config="", name="empty")

    # Create zero feature importance.
    feature_usage_counts = variables.Variable(
        initial_value=array_ops.zeros([1], dtypes.int64),
        name="feature_usage_counts",
        trainable=False)
    feature_gains = variables.Variable(
        initial_value=array_ops.zeros([1], dtypes.float32),
        name="feature_gains",
        trainable=False)

    resources.initialize_resources(resources.shared_resources()).run()
    # initialize_all_variables() is a deprecated alias of this initializer.
    variables.global_variables_initializer().run()

    with ops.control_dependencies([
        ensemble_optimizer_ops.add_trees_to_ensemble(
            tree_ensemble_handle,
            self._ensemble_to_add.SerializeToString(),
            feature_usage_counts, [2],
            feature_gains, [0.4], [[]],
            learning_rate=1.0)
    ]):
      result = model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1]

    # Output.
    output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
    output_ensemble.ParseFromString(result.eval())
    # Check the count before indexing so a missing tree fails with an
    # assertion message rather than an IndexError.
    self.assertEqual(1, len(output_ensemble.trees))
    self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[0])

    self.assertAllEqual([1.0], output_ensemble.tree_weights)

    self.assertEqual(
        1, output_ensemble.tree_metadata[0].num_tree_weight_updates)

    self.assertAllEqual([2], feature_usage_counts.eval())
    self.assertArrayNear([0.4], feature_gains.eval(), 1e-6)
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
  """Exports a SavedModel, then hands its tree ensemble to `convert_fn`.

  After `base_strategy` writes the SavedModel, the serialized
  DecisionTreeEnsembleConfig is read back out of the exported graph. If a
  `convert_fn` is configured, it is called with the parsed config, the sorted
  feature names, the lengths of `dense_floats`, `sparse_float_indices` and
  `sparse_int_indices`, the export directory and `eval_result`.

  Args:
    estimator: estimator to export; forwarded to `base_strategy.export`.
    export_dir: base export directory; forwarded to `base_strategy.export`.
    checkpoint_path: optional checkpoint; forwarded to `base_strategy.export`.
      It defaults to None, matching the other `export_fn` variant.
    eval_result: optional evaluation result; forwarded to
      `base_strategy.export` and to `convert_fn`.

  Returns:
    The directory returned by `base_strategy.export`.
  """
  result_dir = base_strategy.export(estimator, export_dir,
                                    checkpoint_path,
                                    eval_result)
  with ops.Graph().as_default() as graph:
    with tf_session.Session(graph=graph) as sess:
      saved_model_loader.load(
          sess, [tag_constants.SERVING], result_dir)
      # Note: This is GTFlow internal API and might change.
      ensemble_model = graph.get_operation_by_name(
          "ensemble_model/TreeEnsembleSerialize")
      _, dfec_str = sess.run(ensemble_model.outputs)
      dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
      dtec.ParseFromString(dfec_str)
      # Export the result in the same folder as the saved model.
      # Conversion is optional: skip it rather than calling None.
      if convert_fn:
        convert_fn(dtec, sorted_feature_names,
                   len(dense_floats),
                   len(sparse_float_indices),
                   len(sparse_int_indices), result_dir, eval_result)
  return result_dir
def testTrainFnChiefScalingNumberOfExamples(self):
  """Tests the train function on chief, without bias centering.

  `examples_per_layer` is given as a callable of the layer index rather than
  a constant.
  """
  with self.test_session() as sess:
    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
    learner_config = learner_pb2.LearnerConfig()
    learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
    learner_config.num_classes = 2
    learner_config.regularization.l1 = 0
    learner_config.regularization.l2 = 0
    learner_config.constraints.max_tree_depth = 1
    learner_config.constraints.min_node_weight = 0
    # Examples required per layer: 2 ** layer (as an int64 tensor).
    num_examples_fn = (
        lambda layer: math_ops.pow(math_ops.cast(2, dtypes.int64), layer) * 1)
    features = {}
    features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)

    gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
        is_chief=True,
        num_ps_replicas=0,
        center_bias=False,
        ensemble_handle=ensemble_handle,
        examples_per_layer=num_examples_fn,
        learner_config=learner_config,
        features=features)

    predictions = array_ops.constant(
        [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
    partition_ids = array_ops.zeros([4], dtypes.int32)
    ensemble_stamp = variables.Variable(
        initial_value=0,
        name="ensemble_stamp",
        trainable=False,
        dtype=dtypes.int64)

    predictions_dict = {
        "predictions": predictions,
        "predictions_no_dropout": predictions,
        "partition_ids": partition_ids,
        "ensemble_stamp": ensemble_stamp,
        "num_trees": 12,
    }

    labels = array_ops.ones([4, 1], dtypes.float32)
    weights = array_ops.ones([4, 1], dtypes.float32)
    # Create train op.
    train_op = gbdt_model.train(
        loss=math_ops.reduce_mean(
            _squared_loss(labels, weights, predictions)),
        predictions_dict=predictions_dict,
        labels=labels)
    variables.global_variables_initializer().run()
    resources.initialize_resources(resources.shared_resources()).run()

    # On first run, expect no splits to be chosen because the quantile
    # buckets will not be ready.
    train_op.run()
    stamp_token, serialized = model_ops.tree_ensemble_serialize(
        ensemble_handle)
    output = tree_config_pb2.DecisionTreeEnsembleConfig()
    output.ParseFromString(serialized.eval())
    self.assertEqual(len(output.trees), 0)
    self.assertEqual(len(output.tree_weights), 0)
    self.assertEqual(stamp_token.eval(), 1)

    # Update the stamp to be able to run a second time.
    sess.run([ensemble_stamp.assign_add(1)])

    # On second run, expect a trivial split to be chosen to basically
    # predict the average.
    train_op.run()
    stamp_token, serialized = model_ops.tree_ensemble_serialize(
        ensemble_handle)
    output = tree_config_pb2.DecisionTreeEnsembleConfig()
    output.ParseFromString(serialized.eval())
    self.assertEqual(len(output.trees), 1)
    self.assertAllClose(output.tree_weights, [0.1])
    self.assertEqual(stamp_token.eval(), 2)
    expected_tree = """ nodes { dense_float_binary_split { threshold: 1.0 left_id: 1 right_id: 2 } node_metadata { gain: 0 } } nodes { leaf { vector { value: 0.25 } } } nodes { leaf { vector { value: 0.0 } } }"""
    self.assertProtoEquals(expected_tree, output.trees[0])
def testSerialization(self):
  """An ensemble survives a serialize / deserialize round trip.

  The ensemble is serialized in one graph, then loaded under a different
  stamp in a fresh graph. There it is used for prediction and re-serialized;
  the re-serialized bytes must equal the first serialization.
  """
  with ops.Graph().as_default() as first_graph:
    with self.test_session(first_graph):
      config = tree_config_pb2.DecisionTreeEnsembleConfig()
      # Bias tree only for second class.
      bias_tree = config.trees.add()
      _append_to_leaf(bias_tree.nodes.add().leaf, 1, -0.2)
      config.tree_weights.append(1.0)

      # Depth 2 tree.
      deep_tree = config.trees.add()
      config.tree_weights.append(1.0)
      _set_float_split(
          deep_tree.nodes.add().sparse_float_binary_split_default_right.split,
          1, 4.0, 1, 2)
      _set_float_split(deep_tree.nodes.add().dense_float_binary_split, 0, 9.0,
                       3, 4)
      _append_to_leaf(deep_tree.nodes.add().leaf, 0, 0.5)
      _append_to_leaf(deep_tree.nodes.add().leaf, 1, 1.2)
      _append_to_leaf(deep_tree.nodes.add().leaf, 0, -0.9)

      handle = model_ops.tree_ensemble_variable(
          stamp_token=7,
          tree_ensemble_config=config.SerializeToString(),
          name="saver_tree")
      stamp, serialized = model_ops.tree_ensemble_serialize(handle)
      resources.initialize_resources(resources.shared_resources()).run()
      self.assertEqual(stamp.eval(), 7)
      serialized_bytes = serialized.eval()

  with ops.Graph().as_default() as second_graph:
    with self.test_session(second_graph):
      restored_handle = model_ops.tree_ensemble_variable(
          stamp_token=9,
          tree_ensemble_config=serialized_bytes,
          name="saver_tree2")
      resources.initialize_resources(resources.shared_resources()).run()

      # Prepare learner config.
      learner = learner_pb2.LearnerConfig()
      learner.num_classes = 3

      predictions, _, _ = prediction_ops.gradient_trees_prediction(
          restored_handle,
          self._seed, [self._dense_float_tensor],
          [self._sparse_float_indices1, self._sparse_float_indices2],
          [self._sparse_float_values1, self._sparse_float_values2],
          [self._sparse_float_shape1, self._sparse_float_shape2],
          [self._sparse_int_indices1],
          [self._sparse_int_values1],
          [self._sparse_int_shape1],
          learner_config=learner.SerializeToString(),
          apply_dropout=False,
          apply_averaging=False,
          center_bias=False,
          reduce_dim=True)

      # Re-serialize tree.
      restamp, reserialized = model_ops.tree_ensemble_serialize(
          restored_handle)

      self.assertEqual(restamp.eval(), 9)
      self.assertEqual(reserialized.eval(), serialized_bytes)

      # The first example will get bias class 1 -0.2 from first tree and
      # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2],
      # the second example will get the same bias class 1 -0.2 and leaf 3
      # payload of class 1 1.2 hence [0.0, 1.0]. No leaf carries a class 2
      # value, and only the class 0 and class 1 columns are compared.
      self.assertAllClose(predictions.eval(), [[0.5, -0.2], [0, 1.0]])
def testRestore(self):
  """Restoring a checkpoint brings back the ensemble as it was when saved.

  Two trees are in the ensemble at save time. A third is added afterwards,
  and it must be absent after restoring into a fresh graph.
  """
  # Calling self.test_session() without a graph specified results in
  # TensorFlowTestCase caching the session and returning the same one
  # every time. In this test, we need to create two different sessions
  # which is why we also create a graph and pass it to self.test_session()
  # to ensure no caching occurs under the hood.
  save_path = os.path.join(self.get_temp_dir(), "restore-test")
  with ops.Graph().as_default() as graph:
    with self.test_session(graph) as sess:
      tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree = tree_ensemble_config.trees.add()
      tree_ensemble_config.tree_metadata.add().is_finalized = True
      tree_ensemble_config.tree_weights.append(1.0)
      _append_to_leaf(tree.nodes.add().leaf, 0, -0.1)

      # The ensembles below are only ever passed to add_trees_to_ensemble,
      # which gives each added tree the weight `learning_rate`. They hold
      # just the tree itself; no weight is appended, and in particular none
      # is appended to tree_ensemble_config, which must keep one weight per
      # tree.
      tree_ensemble_config2 = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree2 = tree_ensemble_config2.trees.add()
      _append_to_leaf(tree2.nodes.add().leaf, 0, -1.0)

      tree_ensemble_config3 = tree_config_pb2.DecisionTreeEnsembleConfig()
      tree3 = tree_ensemble_config3.trees.add()
      _append_to_leaf(tree3.nodes.add().leaf, 0, -10.0)

      # Prepare learner config.
      learner_config = learner_pb2.LearnerConfig()
      learner_config.num_classes = 2

      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=3,
          tree_ensemble_config=tree_ensemble_config.SerializeToString(),
          name="restore_tree")
      feature_usage_counts = variables.Variable(
          initial_value=array_ops.zeros([1], dtypes.int64),
          name="feature_usage_counts",
          trainable=False)
      feature_gains = variables.Variable(
          initial_value=array_ops.zeros([1], dtypes.float32),
          name="feature_gains",
          trainable=False)

      resources.initialize_resources(resources.shared_resources()).run()
      # initialize_all_variables() is a deprecated alias of this initializer.
      variables.global_variables_initializer().run()
      my_saver = saver.Saver()

      with ops.control_dependencies([
          ensemble_optimizer_ops.add_trees_to_ensemble(
              tree_ensemble_handle,
              tree_ensemble_config2.SerializeToString(),
              feature_usage_counts, [0],
              feature_gains, [0], [[]],
              learning_rate=1)
      ]):
        result, _, _ = prediction_ops.gradient_trees_prediction(
            tree_ensemble_handle,
            self._seed, [self._dense_float_tensor],
            [self._sparse_float_indices1, self._sparse_float_indices2],
            [self._sparse_float_values1, self._sparse_float_values2],
            [self._sparse_float_shape1, self._sparse_float_shape2],
            [self._sparse_int_indices1],
            [self._sparse_int_values1],
            [self._sparse_int_shape1],
            learner_config=learner_config.SerializeToString(),
            apply_dropout=False,
            apply_averaging=False,
            center_bias=False,
            reduce_dim=True)
        # -0.1 (first tree) + -1.0 (second tree).
        self.assertAllClose([[-1.1], [-1.1]], result.eval())
      # Save before adding other trees.
      val = my_saver.save(sess, save_path)
      self.assertEqual(save_path, val)

      # Add more trees after saving.
      with ops.control_dependencies([
          ensemble_optimizer_ops.add_trees_to_ensemble(
              tree_ensemble_handle,
              tree_ensemble_config3.SerializeToString(),
              feature_usage_counts, [0],
              feature_gains, [0], [[]],
              learning_rate=1)
      ]):
        result, _, _ = prediction_ops.gradient_trees_prediction(
            tree_ensemble_handle,
            self._seed, [self._dense_float_tensor],
            [self._sparse_float_indices1, self._sparse_float_indices2],
            [self._sparse_float_values1, self._sparse_float_values2],
            [self._sparse_float_shape1, self._sparse_float_shape2],
            [self._sparse_int_indices1],
            [self._sparse_int_values1],
            [self._sparse_int_shape1],
            learner_config=learner_config.SerializeToString(),
            apply_dropout=False,
            apply_averaging=False,
            center_bias=False,
            reduce_dim=True)
        self.assertAllClose(result.eval(), [[-11.1], [-11.1]])

  # Start a second session. In that session the parameter nodes
  # have not been initialized either.
  with ops.Graph().as_default() as graph:
    with self.test_session(graph) as sess:
      tree_ensemble_handle = model_ops.tree_ensemble_variable(
          stamp_token=0, tree_ensemble_config="", name="restore_tree")
      my_saver = saver.Saver()
      my_saver.restore(sess, save_path)
      result, _, _ = prediction_ops.gradient_trees_prediction(
          tree_ensemble_handle,
          self._seed, [self._dense_float_tensor],
          [self._sparse_float_indices1, self._sparse_float_indices2],
          [self._sparse_float_values1, self._sparse_float_values2],
          [self._sparse_float_shape1, self._sparse_float_shape2],
          [self._sparse_int_indices1],
          [self._sparse_int_values1],
          [self._sparse_int_shape1],
          learner_config=learner_config.SerializeToString(),
          apply_dropout=False,
          apply_averaging=False,
          center_bias=False,
          reduce_dim=True)
      # Make sure we only have the first and second tree.
      # The third tree was added after the save.
      self.assertAllClose(result.eval(), [[-1.1], [-1.1]])
def _make_trees(self):
  """Builds a two-tree ensemble fixture and its feature column names.

  The first tree is a single leaf (-1.0, weight 1.0). The second tree
  (weight 0.1) splits on a dense float, a sparse float and a categorical id
  feature.

  Returns:
    A (DecisionTreeEnsembleConfig, list of feature column names) tuple.
  """
  text_proto = """ trees { nodes { leaf { vector { value: -1 } } } } trees { nodes { dense_float_binary_split { feature_column: 0 threshold: 1740.0 left_id: 1 right_id: 2 } node_metadata { gain: 500 } } nodes { leaf { vector { value: 0.6 } } } nodes { sparse_float_binary_split_default_left { split { feature_column: 0 threshold: 1500.0 left_id: 3 right_id: 4 } } node_metadata { gain: 500 } } nodes { categorical_id_binary_split { feature_column: 0 feature_id: 5 left_id: 5 right_id: 6 } node_metadata { gain: 500 } } nodes { leaf { vector { value: 0.8 } } } nodes { leaf { vector { value: 0.5 } } } nodes { leaf { vector { value: 0.3 } } } } tree_weights: 1.0 tree_weights: 0.1 """
  ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
  text_format.Merge(text_proto, ensemble)
  column_names = ["feature_b", "feature_a", "feature_d"]
  return ensemble, column_names
def testConvertModel(self):
  """Converts a two-tree ensemble to the universal decision-tree format.

  Leaf values of the second tree are expected to come out multiplied by its
  0.1 tree weight, and features are referenced by column name.
  """
  source_text = """ trees { nodes { leaf { vector { value: -1 } } } } trees { nodes { dense_float_binary_split { feature_column: 0 threshold: 1740.0 left_id: 1 right_id: 2 } node_metadata { gain: 500 } } nodes { leaf { vector { value: 0.6 } } } nodes { sparse_float_binary_split_default_left { split { feature_column: 0 threshold: 1500.0 left_id: 3 right_id: 4 } } node_metadata { gain: 500 } } nodes { categorical_id_binary_split { feature_column: 0 feature_id: 5 left_id: 5 right_id: 6 } node_metadata { gain: 500 } } nodes { leaf { vector { value: 0.8 } } } nodes { leaf { vector { value: 0.5 } } } nodes { leaf { vector { value: 0.3 } } } } tree_weights: 1.0 tree_weights: 0.1 """
  ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
  text_format.Merge(source_text, ensemble_config)

  # The feature columns in the order they were added.
  column_names = ["feature_b", "feature_a", "feature_d"]
  universal_model = custom_export_strategy.convert_to_universal_format(
      ensemble_config, column_names, 1, 1, 1)

  expected_model = """ features { key: "feature_a" } features { key: "feature_b" } features { key: "feature_d" } model { ensemble { summation_combination_technique { } members { submodel { decision_tree { nodes { node_id { } leaf { vector { value { float_value: -1.0 } } } } } } submodel_id { } } members { submodel { decision_tree { nodes { node_id { } binary_node { left_child_id { value: 1 } right_child_id { value: 2 } inequality_left_child_test { feature_id { id { value: "feature_b" } } threshold { float_value: 1740.0 } } } } nodes { node_id { value: 1 } leaf { vector { value { float_value: 0.06 } } } } nodes { node_id { value: 2 } binary_node { left_child_id { value: 3 } right_child_id { value: 4 } inequality_left_child_test { feature_id { id { value: "feature_a" } } threshold { float_value: 1500.0 } } } } nodes { node_id { value: 3 } binary_node { left_child_id { value: 5 } right_child_id { value: 6 } default_direction: RIGHT custom_left_child_test { 
[type.googleapis.com/tensorflow.decision_trees.MatchingValuesTest] { feature_id { id { value: "feature_d" } } value { int64_value: 5 } } } } } nodes { node_id { value: 4 } leaf { vector { value { float_value: 0.08 } } } } nodes { node_id { value: 5 } leaf { vector { value { float_value: 0.05 } } } } nodes { node_id { value: 6 } leaf { vector { value { float_value: 0.03 } } } } } } submodel_id { value: 1 } } } }"""
  self.assertProtoEquals(expected_model, universal_model)
def testTrainFnMulticlassTreePerClass(self):
  """Tests the GBDT train for multiclass tree per class strategy.

  With the tree-per-class strategy a single tree is grown for one class, so
  its leaves hold sparse vectors with a single entry for that class.
  """
  with self.test_session() as sess:
    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0, tree_ensemble_config="", name="tree_ensemble")

    learner_config = learner_pb2.LearnerConfig()
    learner_config.learning_rate_tuner.fixed.learning_rate = 1
    # Use the tree-per-class multiclass strategy.
    learner_config.multi_class_strategy = (
        learner_pb2.LearnerConfig.TREE_PER_CLASS)
    learner_config.num_classes = 5
    learner_config.regularization.l1 = 0
    # Small L2 term to keep the matrix invertible.
    learner_config.regularization.l2 = 1e-5
    learner_config.constraints.max_tree_depth = 1
    learner_config.constraints.min_node_weight = 0
    features = {
        "dense_float":
            array_ops.constant([[1.0], [1.5], [2.0]], dtypes.float32),
    }

    gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
        is_chief=True,
        num_ps_replicas=0,
        center_bias=False,
        ensemble_handle=ensemble_handle,
        examples_per_layer=1,
        learner_config=learner_config,
        features=features)

    batch_size = 3
    predictions = array_ops.constant(
        [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
         [0.0, 0.0, 0.0, 2.0, 1.2]],
        dtype=dtypes.float32)

    labels = array_ops.constant([[2], [2], [3]], dtype=dtypes.float32)
    weights = array_ops.ones([batch_size, 1], dtypes.float32)

    partition_ids = array_ops.zeros([batch_size], dtypes.int32)
    ensemble_stamp = variables.Variable(
        initial_value=0,
        name="ensemble_stamp",
        trainable=False,
        dtype=dtypes.int64)

    predictions_dict = {
        "predictions": predictions,
        "predictions_no_dropout": predictions,
        "partition_ids": partition_ids,
        "ensemble_stamp": ensemble_stamp,
        # With 5 classes, 13 existing trees should result in a tree built
        # for class 3 (13 % 5), the class whose index is asserted below.
        "num_trees": 13,
    }

    # Create train op.
    train_op = gbdt_model.train(
        loss=math_ops.reduce_mean(
            losses.per_example_maxent_loss(
                labels,
                weights,
                predictions,
                num_classes=learner_config.num_classes)[0]),
        predictions_dict=predictions_dict,
        labels=labels)

    variables.global_variables_initializer().run()
    resources.initialize_resources(resources.shared_resources()).run()

    # On first run, expect no splits to be chosen because the quantile
    # buckets will not be ready.
    train_op.run()
    stamp_token, serialized = model_ops.tree_ensemble_serialize(
        ensemble_handle)
    output = tree_config_pb2.DecisionTreeEnsembleConfig()
    output.ParseFromString(serialized.eval())
    self.assertEqual(len(output.trees), 0)
    self.assertEqual(len(output.tree_weights), 0)
    self.assertEqual(stamp_token.eval(), 1)

    # Update the stamp to be able to run a second time.
    sess.run([ensemble_stamp.assign_add(1)])
    # On second run, expect a trivial split to be chosen to basically
    # predict the average.
    train_op.run()
    stamp_token, serialized = model_ops.tree_ensemble_serialize(
        ensemble_handle)
    output = tree_config_pb2.DecisionTreeEnsembleConfig()
    output.ParseFromString(serialized.eval())
    self.assertEqual(len(output.trees), 1)
    self.assertAllClose(output.tree_weights, [1])
    self.assertEqual(stamp_token.eval(), 2)

    # One node for a split, two children nodes.
    self.assertEqual(3, len(output.trees[0].nodes))

    # Leaves will have a sparse vector for class 3.
    self.assertEqual(
        1, len(output.trees[0].nodes[1].leaf.sparse_vector.index))
    self.assertEqual(
        3, output.trees[0].nodes[1].leaf.sparse_vector.index[0])
    self.assertAlmostEqual(
        -1.13134455681, output.trees[0].nodes[1].leaf.sparse_vector.value[0])

    self.assertEqual(
        1, len(output.trees[0].nodes[2].leaf.sparse_vector.index))
    self.assertEqual(
        3, output.trees[0].nodes[2].leaf.sparse_vector.index[0])
    self.assertAlmostEqual(
        0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0])
def testTrainFnMulticlassDiagonalHessian(self):
    """Tests the GBDT train for multiclass diagonal hessian."""
    with self.test_session() as sess:
        ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=0, tree_ensemble_config="", name="tree_ensemble")

        learner_config = learner_pb2.LearnerConfig()
        learner_config.learning_rate_tuner.fixed.learning_rate = 1
        # Use the diagonal hessian multiclass strategy.
        learner_config.multi_class_strategy = (
            learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
        learner_config.num_classes = 5
        learner_config.regularization.l1 = 0
        # To make matrix inversible.
        learner_config.regularization.l2 = 1e-5
        learner_config.constraints.max_tree_depth = 1
        learner_config.constraints.min_node_weight = 0

        batch_size = 3
        features = {
            "dense_float": array_ops.ones([batch_size, 1], dtypes.float32),
        }

        gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
            is_chief=True,
            num_ps_replicas=0,
            center_bias=False,
            ensemble_handle=ensemble_handle,
            examples_per_layer=1,
            learner_config=learner_config,
            features=features)

        predictions = array_ops.constant(
            [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0],
             [0.0, 0.0, 0.0, 0.0, 1.2]],
            dtype=dtypes.float32)

        labels = array_ops.constant([[2], [2], [3]], dtype=dtypes.float32)
        weights = array_ops.ones([batch_size, 1], dtypes.float32)

        partition_ids = array_ops.zeros([batch_size], dtypes.int32)
        ensemble_stamp = variables.Variable(
            initial_value=0,
            name="ensemble_stamp",
            trainable=False,
            dtype=dtypes.int64)

        predictions_dict = {
            "predictions": predictions,
            "predictions_no_dropout": predictions,
            "partition_ids": partition_ids,
            "ensemble_stamp": ensemble_stamp,
            "num_trees": 0,
        }

        # Create train op.
        train_op = gbdt_model.train(
            loss=math_ops.reduce_mean(
                losses.per_example_maxent_loss(
                    labels,
                    weights,
                    predictions,
                    num_classes=learner_config.num_classes)[0]),
            predictions_dict=predictions_dict,
            labels=labels)
        variables.global_variables_initializer().run()
        resources.initialize_resources(resources.shared_resources()).run()

        # On first run, expect no splits to be chosen because the quantile
        # buckets will not be ready.
        train_op.run()
        stamp_token, serialized = model_ops.tree_ensemble_serialize(
            ensemble_handle)
        output = tree_config_pb2.DecisionTreeEnsembleConfig()
        output.ParseFromString(serialized.eval())
        self.assertEqual(len(output.trees), 0)
        self.assertEqual(len(output.tree_weights), 0)
        self.assertEqual(stamp_token.eval(), 1)

        # Update the stamp to be able to run a second time.
        sess.run([ensemble_stamp.assign_add(1)])

        # On second run, expect a trivial split to be chosen to basically
        # predict the average.
        train_op.run()
        stamp_token, serialized = model_ops.tree_ensemble_serialize(
            ensemble_handle)
        output = tree_config_pb2.DecisionTreeEnsembleConfig()
        output.ParseFromString(serialized.eval())
        self.assertEqual(len(output.trees), 1)
        self.assertAllClose(output.tree_weights, [1])
        self.assertEqual(stamp_token.eval(), 2)

        # Leaf should have a dense vector with one value per class (5).
        expected = [
            -1.26767396927, -1.13043296337, 4.58542203903, 1.81428349018,
            -2.43038392067
        ]
        leaf_values = output.trees[0].nodes[1].leaf.vector.value
        for i in range(learner_config.num_classes):
            self.assertAlmostEqual(expected[i], leaf_values[i])
def testWithExistingEnsembleAndDropout(self):
    """Tests adding a tree to a 10-tree ensemble while some trees are dropped."""
    with self.test_session():
        tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
        # Add 10 trees with weights 1..10, each recorded as updated once.
        for i in range(0, 10):
            tree = tree_ensemble.trees.add()
            _append_to_leaf(tree.nodes.add().leaf, 0, -0.4)
            tree_ensemble.tree_weights.append(i + 1)
            meta = tree_ensemble.tree_metadata.add()
            meta.num_tree_weight_updates = 1
        tree_ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=0,
            tree_ensemble_config=tree_ensemble.SerializeToString(),
            name="existing")

        # Create non-zero feature importance.
        feature_usage_counts = variables.Variable(
            initial_value=np.array([2, 3], np.int64),
            name="feature_usage_counts",
            trainable=False)
        feature_gains = variables.Variable(
            initial_value=np.array([0.0, 0.3], np.float32),
            name="feature_gains",
            trainable=False)

        resources.initialize_resources(resources.shared_resources()).run()
        # `initialize_all_variables` is a deprecated alias of this initializer.
        variables.global_variables_initializer().run()

        dropped = [1, 6, 8]
        dropped_original_weights = [2.0, 7.0, 9.0]

        output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig()
        with ops.control_dependencies([
            ensemble_optimizer_ops.add_trees_to_ensemble(
                tree_ensemble_handle,
                self._ensemble_to_add.SerializeToString(),
                feature_usage_counts, [1, 2],
                feature_gains, [0.5, 0.3],
                [dropped, dropped_original_weights],
                learning_rate=0.1)
        ]):
            output_ensemble.ParseFromString(
                model_ops.tree_ensemble_serialize(tree_ensemble_handle)
                [1].eval())

        # Output.
        self.assertEqual(11, len(output_ensemble.trees))
        self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[10])
        self.assertAllClose(4.5, output_ensemble.tree_weights[10])
        # The expected weights are consistent with the dropped trees (1, 6, 8)
        # being rescaled by 3/4 (2 -> 1.5, 7 -> 5.25, 9 -> 6.75) and the new
        # tree receiving (2 + 7 + 9) / 4 = 4.5.
        self.assertAllClose(
            [1., 1.5, 3., 4., 5., 6., 5.25, 8., 6.75, 10., 4.5],
            output_ensemble.tree_weights)

        # Only the dropped trees record an additional weight update.
        expected_num_updates = [1, 2, 1, 1, 1, 1, 2, 1, 2, 1, 1]
        for i, expected in enumerate(expected_num_updates):
            self.assertEqual(
                expected,
                output_ensemble.tree_metadata[i].num_tree_weight_updates)

        # Usage counts grow by [1, 2]; gains grow by [0.5, 0.3] * 0.1, i.e.
        # [0.0 + 0.05, 0.3 + 0.03].
        self.assertAllEqual([3, 5], feature_usage_counts.eval())
        self.assertArrayNear([0.05, 0.33], feature_gains.eval(), 1e-6)
def model_builder(features,
                  labels,
                  mode,
                  params,
                  config,
                  output_type=ModelBuilderOutputType.MODEL_FN_OPS):
    """Multi-machine batch gradient descent tree model.

    Args:
      features: `Tensor` or `dict` of `Tensor` objects.
      labels: Labels used to train on.
      mode: Mode we are in. (TRAIN/EVAL/INFER)
      params: A dict of hyperparameters.
        The following hyperparameters are expected:
        * head: A `Head` instance.
        * learner_config: A config for the learner.
        * feature_columns: An iterable containing all the feature columns used
            by the model.
        * examples_per_layer: Number of examples to accumulate before growing a
            layer. It can also be a function that computes the number of
            examples based on the depth of the layer that's being built.
        * weight_column_name: The name of weight column. This key is removed
            from the features passed to the tree model.
        * center_bias: Whether a separate tree should be created for first
            fitting the bias.
        * num_trees: If truthy, training stops after this many trees (one
            extra tree is allowed when `center_bias` is set).
        * use_core_libs: Passed to the tree model as `use_core_columns`; for
            `MODEL_FN_OPS` output it also selects
            `head.create_estimator_spec` when the head provides it.
        * logits_modifier_function: Optional callable
            `(logits, features, mode) -> logits` applied to the predictions.
        * output_leaf_index: Whether to add the leaf index to the predictions
            (when the tree model produces it).
        * num_quantiles: Passed through to the tree model as `num_quantiles`.
        * override_global_step_value: If after the training is done, global
            step value must be reset to this value. This is particularly
            useful for hyper parameter tuning, which can't recognize early
            stopping due to the number of trees. If None, no override of
            global step will happen. Optional.
      config: `RunConfig` of the estimator.
      output_type: Whether to return ModelFnOps (old interface) or
        EstimatorSpec (new interface).

    Returns:
      A `ModelFnOps` object when `output_type` is
      `ModelBuilderOutputType.MODEL_FN_OPS`, or an `EstimatorSpec` when it is
      `ModelBuilderOutputType.ESTIMATOR_SPEC`.

    Raises:
      ValueError: if inputs are not valid: `features` or `config` is None,
        `each_tree_start` is set without a positive
        `each_tree_start_num_layers`, `ESTIMATOR_SPEC` output is requested
        from a head without `create_estimator_spec`, or `output_type` is not
        a supported `ModelBuilderOutputType`.
    """
    head = params["head"]
    learner_config = params["learner_config"]
    examples_per_layer = params["examples_per_layer"]
    feature_columns = params["feature_columns"]
    weight_column_name = params["weight_column_name"]
    num_trees = params["num_trees"]
    use_core_libs = params["use_core_libs"]
    logits_modifier_function = params["logits_modifier_function"]
    output_leaf_index = params["output_leaf_index"]
    override_global_step_value = params.get("override_global_step_value", None)
    num_quantiles = params["num_quantiles"]

    if features is None:
        raise ValueError("At least one feature must be specified.")

    if config is None:
        raise ValueError("Missing estimator RunConfig.")
    if config.session_config is not None:
        # NOTE(review): this mutates the caller's session_config proto in
        # place; kept as-is since the replaced `config` below is local.
        session_config = config.session_config
        session_config.allow_soft_placement = True
    else:
        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    config = config.replace(session_config=session_config)

    center_bias = params["center_bias"]

    if isinstance(features, ops.Tensor):
        features = {features.name: features}

    # Make a shallow copy of features to ensure downstream usage
    # is unaffected by modifications in the model function.
    training_features = copy.copy(features)
    training_features.pop(weight_column_name, None)
    global_step = training_util.get_global_step()

    # Optionally seed the ensemble with a user-provided starting tree.
    initial_ensemble = ""
    if learner_config.each_tree_start.nodes:
        if learner_config.each_tree_start_num_layers <= 0:
            raise ValueError("You must provide each_tree_start_num_layers.")
        num_layers = learner_config.each_tree_start_num_layers
        initial_ensemble = """
             trees { %s }
             tree_weights: 0.1
             tree_metadata {
              num_tree_weight_updates: 1
              num_layers_grown: %d
              is_finalized: false
             }
             """ % (text_format.MessageToString(
                 learner_config.each_tree_start), num_layers)
        tree_ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
        text_format.Merge(initial_ensemble, tree_ensemble_proto)
        initial_ensemble = tree_ensemble_proto.SerializeToString()

    # Place the ensemble resource on the same device as the global step.
    with ops.device(global_step.device):
        ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=0,
            tree_ensemble_config=initial_ensemble,  # Initialize the ensemble.
            name="ensemble_model")

    # Create GBDT model.
    gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
        is_chief=config.is_chief,
        num_ps_replicas=config.num_ps_replicas,
        ensemble_handle=ensemble_handle,
        center_bias=center_bias,
        examples_per_layer=examples_per_layer,
        learner_config=learner_config,
        feature_columns=feature_columns,
        logits_dimension=head.logits_dimension,
        features=training_features,
        use_core_columns=use_core_libs,
        output_leaf_index=output_leaf_index,
        num_quantiles=num_quantiles)

    with ops.name_scope("gbdt", "gbdt_optimizer"):
        predictions_dict = gbdt_model.predict(mode)
        logits = predictions_dict["predictions"]
        if logits_modifier_function:
            logits = logits_modifier_function(logits, features, mode)

        def _train_op_fn(loss):
            """Returns the op to optimize the loss."""
            update_op = gbdt_model.train(loss, predictions_dict, labels)
            # Increment the global step only after the tree update has run.
            with ops.control_dependencies(
                [update_op]), (ops.colocate_with(global_step)):
                update_op = state_ops.assign_add(global_step, 1).op
                return update_op

    create_estimator_spec_op = getattr(head, "create_estimator_spec", None)

    training_hooks = []
    if num_trees:
        if center_bias:
            # The bias-centering tree does not count toward num_trees.
            num_trees += 1
        finalized_trees, attempted_trees = (
            gbdt_model.get_number_of_trees_tensor())
        training_hooks.append(
            trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
                                          finalized_trees,
                                          override_global_step_value))

    if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
        if use_core_libs and callable(create_estimator_spec_op):
            model_fn_ops = head.create_estimator_spec(
                features=features,
                mode=mode,
                labels=labels,
                train_op_fn=_train_op_fn,
                logits=logits)
            model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
                model_fn_ops)
        else:
            model_fn_ops = head.create_model_fn_ops(
                features=features,
                mode=mode,
                labels=labels,
                train_op_fn=_train_op_fn,
                logits=logits)
        if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
            model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
                gbdt_batch.LEAF_INDEX]
        model_fn_ops.training_hooks.extend(training_hooks)
        return model_fn_ops

    if output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
        # Raise instead of `assert` so the check survives `python -O`.
        if not callable(create_estimator_spec_op):
            raise ValueError(
                "ESTIMATOR_SPEC output requires a head that implements "
                "create_estimator_spec.")
        estimator_spec = head.create_estimator_spec(
            features=features,
            mode=mode,
            labels=labels,
            train_op_fn=_train_op_fn,
            logits=logits)
        if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
            estimator_spec.predictions[
                gbdt_batch.LEAF_INDEX] = predictions_dict[gbdt_batch.LEAF_INDEX]
        estimator_spec = estimator_spec._replace(
            training_hooks=training_hooks + list(estimator_spec.training_hooks))
        return estimator_spec

    # Previously this fell through to `return model_fn_ops`, which raised
    # UnboundLocalError because model_fn_ops is never bound on this path.
    raise ValueError("Unsupported output_type: %r" % (output_type,))
def testTrainFnChiefWithBiasCentering(self):
    """Tests the train function running on chief with bias centering."""
    with self.test_session():
        ensemble_handle = model_ops.tree_ensemble_variable(
            stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
        learner_config = learner_pb2.LearnerConfig()
        learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
        learner_config.num_classes = 2
        learner_config.regularization.l1 = 0
        learner_config.regularization.l2 = 0
        learner_config.constraints.max_tree_depth = 1
        learner_config.constraints.min_node_weight = 0
        features = {
            "dense_float": array_ops.ones([4, 1], dtypes.float32),
        }

        gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
            is_chief=True,
            num_ps_replicas=0,
            center_bias=True,
            ensemble_handle=ensemble_handle,
            examples_per_layer=1,
            learner_config=learner_config,
            features=features)

        predictions = array_ops.constant(
            [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
        partition_ids = array_ops.zeros([4], dtypes.int32)
        ensemble_stamp = variables.Variable(
            initial_value=0,
            name="ensemble_stamp",
            trainable=False,
            dtype=dtypes.int64)

        predictions_dict = {
            "predictions": predictions,
            "predictions_no_dropout": predictions,
            "partition_ids": partition_ids,
            "ensemble_stamp": ensemble_stamp,
            "num_trees": 12,
        }

        labels = array_ops.ones([4, 1], dtypes.float32)
        weights = array_ops.ones([4, 1], dtypes.float32)
        # Create train op.
        train_op = gbdt_model.train(
            loss=math_ops.reduce_mean(
                _squared_loss(labels, weights, predictions)),
            predictions_dict=predictions_dict,
            labels=labels)
        variables.global_variables_initializer().run()
        resources.initialize_resources(resources.shared_resources()).run()

        # On first run, expect bias to be centered.
        train_op.run()
        stamp_token, serialized = model_ops.tree_ensemble_serialize(
            ensemble_handle)
        output = tree_config_pb2.DecisionTreeEnsembleConfig()
        output.ParseFromString(serialized.eval())
        # NOTE(review): 0.25 equals mean(labels - predictions) =
        # (1 + 0 + 1 - 1) / 4, consistent with a single bias leaf.
        expected_tree = """
            nodes {
              leaf {
                vector {
                  value: 0.25
                }
              }
            }"""
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(len(output.trees), 1)
        self.assertAllEqual(output.tree_weights, [1.0])
        self.assertProtoEquals(expected_tree, output.trees[0])
        self.assertEqual(stamp_token.eval(), 1)