def testCachedPredictionOnEmptyEnsemble(self): """Tests that prediction on a dummy ensemble does not fail.""" with self.cached_session() as session: # Create a dummy ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto='') tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() # No previous cached values. cached_tree_ids = [0, 0] cached_node_ids = [0, 0] # We have two features: 0 and 1. Values don't matter here on a dummy # ensemble. feature_0_values = [67, 5] feature_1_values = [9, 17] # Run training prediction on the dummy ensemble. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values, feature_1_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) # Nothing changed. self.assertAllClose(cached_tree_ids, new_tree_ids) self.assertAllClose(cached_node_ids, new_node_ids) self.assertAllClose([[0], [0]], logits_updates)
def testCachedPredictionFromPreviousTree(self): """Tests the predictions work when we have cache from previous trees.""" with self.cached_session() as session: tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() text_format.Merge(""" trees { nodes { bucketized_split { feature_id: 1 threshold: 28 left_id: 1 right_id: 2 } metadata { gain: 7.62 } } nodes { leaf { scalar: 1.14 } } nodes { leaf { scalar: 8.79 } } } trees { nodes { bucketized_split { feature_id: 1 threshold: 26 left_id: 1 right_id: 2 } } nodes { bucketized_split { feature_id: 0 threshold: 50 left_id: 3 right_id: 4 } } nodes { leaf { scalar: 7 } } nodes { leaf { scalar: 5 } } nodes { leaf { scalar: 6 } } } trees { nodes { bucketized_split { feature_id: 0 threshold: 34 left_id: 1 right_id: 2 } } nodes { leaf { scalar: -7.0 } } nodes { leaf { scalar: 5.0 } } } tree_metadata { is_finalized: true } tree_metadata { is_finalized: true } tree_metadata { is_finalized: false } tree_weights: 0.1 tree_weights: 0.1 tree_weights: 0.1 """, tree_ensemble_config) # Create existing ensemble with one root split tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() # Two examples, one was cached in node 1 first, another in node 0. cached_tree_ids = [0, 0] cached_node_ids = [1, 0] # We have two features: 0 and 1. feature_0_values = [36, 32] feature_1_values = [11, 27] # Run training prediction. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values, feature_1_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) # Example 1 will get to node 3 in tree 1 and node 2 of tree 2 # Example 2 will get to node 2 in tree 1 and node 1 of tree 2 # We are in the last tree. self.assertAllClose([2, 2], new_tree_ids) # In tree 2, the first example ends up in node 2, the second in node 1. self.assertAllClose([2, 1], new_node_ids) # Example 1: tree 0: 1.14 (already in the cache), tree 1: 5.0, tree 2: 5.0 => # change = 0.1*(5.0+5.0) # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7.0 => # change = 0.1*(1.14+7.0-7.0) self.assertAllClose([[1], [0.114]], logits_updates)
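# A minimal pure-Python sketch (not the TF op) of the cached-update math the
# assertions above rely on: resume each example at its cached (tree, node)
# position, walk bucketized splits (go left iff feature value <= threshold),
# and sum the learning-rate-scaled leaf values of every tree the cache has
# not applied yet. The tree tuples are hand-copied from the proto above;
# treating "cached at node 0" as "tree not applied yet" is a simplification
# of what the op does for partially grown trees.

def _walk(tree, features, node=0):
  """Follows bucketized splits from `node` down to a leaf value."""
  kind, payload = tree[node]
  while kind == 'split':
    feature_id, threshold, left, right = payload
    node = left if features[feature_id] <= threshold else right
    kind, payload = tree[node]
  return payload

# trees[i][node] is ('split', (feature_id, threshold, left_id, right_id))
# or ('leaf', scalar), mirroring the ensemble proto above.
trees = [
    {0: ('split', (1, 28, 1, 2)), 1: ('leaf', 1.14), 2: ('leaf', 8.79)},
    {0: ('split', (1, 26, 1, 2)), 1: ('split', (0, 50, 3, 4)),
     2: ('leaf', 7.0), 3: ('leaf', 5.0), 4: ('leaf', 6.0)},
    {0: ('split', (0, 34, 1, 2)), 1: ('leaf', -7.0), 2: ('leaf', 5.0)},
]

def cached_update(features, cached_tree_id, cached_node_id, lr=0.1):
  start = cached_tree_id if cached_node_id == 0 else cached_tree_id + 1
  return lr * sum(_walk(trees[t], features) for t in range(start, len(trees)))

assert abs(cached_update({0: 36, 1: 11}, 0, 1) - 1.0) < 1e-6    # example 1
assert abs(cached_update({0: 32, 1: 27}, 0, 0) - 0.114) < 1e-6  # example 2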
def testCachedPredictionIsCurrent(self): """Tests that predictions are unchanged when the cache is already current.""" with self.cached_session() as session: tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() text_format.Merge(""" trees { nodes { bucketized_split { feature_id: 1 threshold: 15 left_id: 1 right_id: 2 } metadata { gain: 7.62 original_leaf { scalar: -2 } } } nodes { leaf { scalar: 1.14 } } nodes { leaf { scalar: 8.79 } } } tree_weights: 0.1 tree_metadata { is_finalized: true num_layers_grown: 2 } growing_metadata { num_trees_attempted: 1 num_layers_attempted: 2 } """, tree_ensemble_config) # Create existing ensemble with one root split tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() # Two examples, one was cached in node 1 first, another in node 2. cached_tree_ids = [0, 0] cached_node_ids = [1, 2] # We have two features: 0 and 1. Values don't matter because trees didn't # change. feature_0_values = [67, 5] feature_1_values = [9, 17] # Run training prediction. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values, feature_1_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) # Nothing changed. self.assertAllClose(cached_tree_ids, new_tree_ids) self.assertAllClose(cached_node_ids, new_node_ids) self.assertAllClose([[0], [0]], logits_updates)
def testCachedPredictionFromTheSameTree(self): """Tests that prediction based on previous node in the tree works.""" with self.cached_session() as session: tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() text_format.Merge(""" trees { nodes { bucketized_split { feature_id: 1 threshold: 15 left_id: 1 right_id: 2 } metadata { gain: 7.62 original_leaf { scalar: -2 } } } nodes { bucketized_split { feature_id: 1 threshold: 7 left_id: 3 right_id: 4 } metadata { gain: 1.4 original_leaf { scalar: 7.14 } } } nodes { bucketized_split { feature_id: 0 threshold: 7 left_id: 5 right_id: 6 } metadata { gain: 2.7 original_leaf { scalar: -4.375 } } } nodes { leaf { scalar: 1.14 } } nodes { leaf { scalar: 8.79 } } nodes { leaf { scalar: -5.875 } } nodes { leaf { scalar: -2.075 } } } tree_weights: 0.1 tree_metadata { is_finalized: true num_layers_grown: 2 } growing_metadata { num_trees_attempted: 1 num_layers_attempted: 2 } """, tree_ensemble_config) # Create existing ensemble with one root split tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() # Two examples, one was cached in node 1 first, another in node 0. cached_tree_ids = [0, 0] cached_node_ids = [1, 0] # We have two features: 0 and 1. feature_0_values = [67, 5] feature_1_values = [9, 17] # Run training prediction. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values, feature_1_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) # We are still in the same tree. self.assertAllClose([0, 0], new_tree_ids) # When using the full tree, the first example will end up in node 4, # the second in node 5. self.assertAllClose([4, 5], new_node_ids) # Full predictions for each instance would be 8.79 and -5.875, # so the updates from the previously cached values (7.14 and -2) are # 1.65 and -3.875, which are then multiplied by the learning rate 0.1. self.assertAllClose([[0.1 * 1.65], [0.1 * -3.875]], logits_updates)
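# A hedged note on the arithmetic verified above: when extra layers were grown
# underneath a cached node within the same tree, the logits update is
# lr * (new_leaf_value - cached_node_value), where the cached node's value
# survives in the split metadata's original_leaf field. A tiny worked check:

def same_tree_update(new_leaf_value, cached_node_value, lr=0.1):
  return lr * (new_leaf_value - cached_node_value)

assert abs(same_tree_update(8.79, 7.14) - 0.1 * 1.65) < 1e-9      # node 1 -> 4
assert abs(same_tree_update(-5.875, -2.0) - 0.1 * -3.875) < 1e-9  # node 0 -> 5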
def _bt_model_fn( features, labels, mode, head, feature_columns, tree_hparams, n_batches_per_layer, config, closed_form_grad_and_hess_fn=None, example_id_column_name=None, # TODO(youngheek): replace this later using other options. train_in_memory=False, name='boosted_trees'): """Gradient Boosted Trees model_fn. Args: features: dict of `Tensor`. labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype `int32` or `int64` in the range `[0, n_classes)`. mode: Defines whether this is training, evaluation or prediction. See `ModeKeys`. head: A `head_lib._Head` instance. feature_columns: Iterable of `feature_column._FeatureColumn` model inputs. tree_hparams: TODO. collections.namedtuple for hyper parameters. n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at least n_batches_per_layer accumulations. config: `RunConfig` object to configure the runtime settings. closed_form_grad_and_hess_fn: a function that accepts logits and labels and returns gradients and hessians. By default, they are created by tf.gradients() from the loss. example_id_column_name: Name of the feature for a unique ID per example. Currently experimental -- not exposed to public API. train_in_memory: `bool`, when true, it assumes the dataset is in memory, i.e., input_fn should return the entire dataset as a single batch, and also n_batches_per_layer should be set as 1. name: Name to use for the model. Returns: An `EstimatorSpec` instance. Raises: ValueError: mode or params are invalid, or features has the wrong type. """ is_single_machine = (config.num_worker_replicas <= 1) sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name) center_bias = tree_hparams.center_bias if train_in_memory: assert n_batches_per_layer == 1, ( 'When train_in_memory is enabled, input_fn should return the entire ' 'dataset as a single batch, and n_batches_per_layer should be set as ' '1.') if (not config.is_chief or config.num_worker_replicas > 1 or config.num_ps_replicas > 0): raise ValueError('train_in_memory is supported only for ' 'non-distributed training.') worker_device = control_flow_ops.no_op().device train_op = [] with ops.name_scope(name) as name: # Prepare. global_step = training_util.get_or_create_global_step() bucket_size_list, feature_ids_list = _group_features_by_num_buckets( sorted_feature_columns) # Create Ensemble resources. tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) # Create logits. if mode != model_fn.ModeKeys.TRAIN: input_feature_list = _get_transformed_features(features, sorted_feature_columns) logits = boosted_trees_ops.predict( # For non-TRAIN mode, ensemble doesn't change after initialization, # so no local copy is needed; using tree_ensemble directly. tree_ensemble_handle=tree_ensemble.resource_handle, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension) return head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=control_flow_ops.no_op, logits=logits) # ============== Training graph ============== # Extract input features and set up cache for training. training_state_cache = None if train_in_memory: # cache transformed features as well for in-memory training. 
batch_size = array_ops.shape(labels)[0] input_feature_list, input_cache_op = ( _cache_transformed_features(features, sorted_feature_columns, batch_size)) train_op.append(input_cache_op) training_state_cache = _CacheTrainingStatesUsingVariables( batch_size, head.logits_dimension) else: input_feature_list = _get_transformed_features(features, sorted_feature_columns) if example_id_column_name: example_ids = features[example_id_column_name] training_state_cache = _CacheTrainingStatesUsingHashTable( example_ids, head.logits_dimension) # Variable that determines whether bias centering is needed. center_bias_var = variable_scope.variable( initial_value=center_bias, name='center_bias_needed', trainable=False) if is_single_machine: local_tree_ensemble = tree_ensemble ensemble_reload = control_flow_ops.no_op() else: # Have a local copy of ensemble for the distributed setting. with ops.device(worker_device): local_tree_ensemble = boosted_trees_ops.TreeEnsemble( name=name + '_local', is_local=True) # TODO(soroush): Do partial updates if this becomes a bottleneck. ensemble_reload = local_tree_ensemble.deserialize( *tree_ensemble.serialize()) if training_state_cache: cached_tree_ids, cached_node_ids, cached_logits = ( training_state_cache.lookup()) else: # Always start from the beginning when no cache is set up. batch_size = array_ops.shape(labels)[0] cached_tree_ids, cached_node_ids, cached_logits = ( array_ops.zeros([batch_size], dtype=dtypes.int32), _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32), array_ops.zeros( [batch_size, head.logits_dimension], dtype=dtypes.float32)) with ops.control_dependencies([ensemble_reload]): (stamp_token, num_trees, num_finalized_trees, num_attempted_layers, last_layer_nodes_range) = local_tree_ensemble.get_states() summary.scalar('ensemble/num_trees', num_trees) summary.scalar('ensemble/num_finalized_trees', num_finalized_trees) summary.scalar('ensemble/num_attempted_layers', num_attempted_layers) partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict( tree_ensemble_handle=local_tree_ensemble.resource_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension) logits = cached_logits + partial_logits # Create training graph. def _train_op_fn(loss): """Run one training iteration.""" if training_state_cache: # Cache logits only after center_bias is complete, if it's in progress. train_op.append( control_flow_ops.cond( center_bias_var, control_flow_ops.no_op, lambda: training_state_cache.insert(tree_ids, node_ids, logits)) ) if closed_form_grad_and_hess_fn: gradients, hessians = closed_form_grad_and_hess_fn(logits, labels) else: gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0] hessians = gradients_impl.gradients( gradients, logits, name='Hessians')[0] # TODO(youngheek): perhaps storage could be optimized by storing stats # with the dimension max_splits_per_layer, instead of max_splits (for the # entire tree). 
max_splits = _get_max_splits(tree_hparams) stats_summaries_list = [] for i, feature_ids in enumerate(feature_ids_list): num_buckets = bucket_size_list[i] summaries = [ array_ops.squeeze( boosted_trees_ops.make_stats_summary( node_ids=node_ids, gradients=gradients, hessians=hessians, bucketized_features_list=[input_feature_list[f]], max_splits=max_splits, num_buckets=num_buckets), axis=0) for f in feature_ids ] stats_summaries_list.append(summaries) if train_in_memory and is_single_machine: grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams) else: grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams, stamp_token, n_batches_per_layer, bucket_size_list, config.is_chief) update_model = control_flow_ops.cond( center_bias_var, functools.partial( grower.center_bias, center_bias_var, gradients, hessians, ), functools.partial(grower.grow_tree, stats_summaries_list, feature_ids_list, last_layer_nodes_range)) train_op.append(update_model) with ops.control_dependencies([update_model]): increment_global = distribute_lib.increment_var(global_step) train_op.append(increment_global) return control_flow_ops.group(train_op, name='train_op') estimator_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_train_op_fn, logits=logits) # Add an early stop hook. estimator_spec = estimator_spec._replace( training_hooks=estimator_spec.training_hooks + (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers, tree_hparams.n_trees, tree_hparams.max_depth),)) return estimator_spec
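# The docstring above mentions closed_form_grad_and_hess_fn as an alternative
# to the two tf.gradients() calls inside _train_op_fn. A hedged, pure-Python
# sketch of the expected contract for sigmoid cross-entropy (an assumed loss,
# shown only for illustration): per-example g = p - y and h = p * (1 - p).
import math

def closed_form_grad_and_hess(logit, label):
  p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid probability
  return p - label, p * (1.0 - p)     # (gradient, hessian) wrt the logit

g, h = closed_form_grad_and_hess(0.0, 1.0)
assert abs(g - -0.5) < 1e-9 and abs(h - 0.25) < 1e-9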
def _bt_model_fn( features, labels, mode, head, feature_columns, tree_hparams, n_batches_per_layer, config, closed_form_grad_and_hess_fn=None, example_id_column_name=None, # TODO(youngheek): replace this later using other options. train_in_memory=False, name='boosted_trees'): """Gradient Boosted Trees model_fn. Args: features: dict of `Tensor`. labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype `int32` or `int64` in the range `[0, n_classes)`. mode: Defines whether this is training, evaluation or prediction. See `ModeKeys`. head: A `head_lib._Head` instance. feature_columns: Iterable of `feature_column._FeatureColumn` model inputs. tree_hparams: TODO. collections.namedtuple for hyper parameters. n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at least n_batches_per_layer accumulations. config: `RunConfig` object to configure the runtime settings. closed_form_grad_and_hess_fn: a function that accepts logits and labels and returns gradients and hessians. By default, they are created by tf.gradients() from the loss. example_id_column_name: Name of the feature for a unique ID per example. Currently experimental -- not exposed to public API. train_in_memory: `bool`, when true, it assumes the dataset is in memory, i.e., input_fn should return the entire dataset as a single batch, and also n_batches_per_layer should be set as 1. name: Name to use for the model. Returns: An `EstimatorSpec` instance. Raises: ValueError: mode or params are invalid, or features has the wrong type. """ is_single_machine = (config.num_worker_replicas <= 1) sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name) if train_in_memory: assert n_batches_per_layer == 1, ( 'When train_in_memory is enabled, input_fn should return the entire ' 'dataset as a single batch, and n_batches_per_layer should be set as ' '1.') if (not config.is_chief or config.num_worker_replicas > 1 or config.num_ps_replicas > 0): raise ValueError('train_in_memory is supported only for ' 'non-distributed training.') worker_device = control_flow_ops.no_op().device # Maximum number of splits possible in the whole tree = 2^max_depth - 1. # TODO(youngheek): perhaps storage could be optimized by storing stats with # the dimension max_splits_per_layer, instead of max_splits (for the entire # tree). max_splits = (1 << tree_hparams.max_depth) - 1 train_op = [] with ops.name_scope(name) as name: # Prepare. global_step = training_util.get_or_create_global_step() bucket_size_list, feature_ids_list = _group_features_by_num_buckets( sorted_feature_columns) # Extract input features and set up cache for training. training_state_cache = None if mode == model_fn.ModeKeys.TRAIN and train_in_memory: # Cache transformed features as well for in-memory training. batch_size = array_ops.shape(labels)[0] input_feature_list, input_cache_op = ( _cache_transformed_features(features, sorted_feature_columns, batch_size)) train_op.append(input_cache_op) training_state_cache = _CacheTrainingStatesUsingVariables( batch_size, head.logits_dimension) else: input_feature_list = _get_transformed_features(features, sorted_feature_columns) if mode == model_fn.ModeKeys.TRAIN and example_id_column_name: example_ids = features[example_id_column_name] training_state_cache = _CacheTrainingStatesUsingHashTable( example_ids, head.logits_dimension) # Create Ensemble resources. tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) # Create logits.
if mode != model_fn.ModeKeys.TRAIN: logits = boosted_trees_ops.predict( # For non-TRAIN mode, ensemble doesn't change after initialization, # so no local copy is needed; using tree_ensemble directly. tree_ensemble_handle=tree_ensemble.resource_handle, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension) else: if is_single_machine: local_tree_ensemble = tree_ensemble ensemble_reload = control_flow_ops.no_op() else: # Have a local copy of ensemble for the distributed setting. with ops.device(worker_device): local_tree_ensemble = boosted_trees_ops.TreeEnsemble( name=name + '_local', is_local=True) # TODO(soroush): Do partial updates if this becomes a bottleneck. ensemble_reload = local_tree_ensemble.deserialize( *tree_ensemble.serialize()) if training_state_cache: cached_tree_ids, cached_node_ids, cached_logits = ( training_state_cache.lookup()) else: # Always start from the beginning when no cache is set up. batch_size = array_ops.shape(labels)[0] cached_tree_ids, cached_node_ids, cached_logits = ( array_ops.zeros([batch_size], dtype=dtypes.int32), array_ops.zeros([batch_size], dtype=dtypes.int32), array_ops.zeros( [batch_size, head.logits_dimension], dtype=dtypes.float32)) with ops.control_dependencies([ensemble_reload]): (stamp_token, num_trees, num_finalized_trees, num_attempted_layers, last_layer_nodes_range) = local_tree_ensemble.get_states() summary.scalar('ensemble/num_trees', num_trees) summary.scalar('ensemble/num_finalized_trees', num_finalized_trees) summary.scalar('ensemble/num_attempted_layers', num_attempted_layers) partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict( tree_ensemble_handle=local_tree_ensemble.resource_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension) logits = cached_logits + partial_logits # Create training graph. 
def _train_op_fn(loss): """Run one training iteration.""" if training_state_cache: train_op.append(training_state_cache.insert(tree_ids, node_ids, logits)) if closed_form_grad_and_hess_fn: gradients, hessians = closed_form_grad_and_hess_fn(logits, labels) else: gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0] hessians = gradients_impl.gradients( gradients, logits, name='Hessians')[0] stats_summaries_list = [] for i, feature_ids in enumerate(feature_ids_list): num_buckets = bucket_size_list[i] summaries = [ array_ops.squeeze( boosted_trees_ops.make_stats_summary( node_ids=node_ids, gradients=gradients, hessians=hessians, bucketized_features_list=[input_feature_list[f]], max_splits=max_splits, num_buckets=num_buckets), axis=0) for f in feature_ids ] stats_summaries_list.append(summaries) accumulators = [] def grow_tree_from_stats_summaries(stats_summaries_list, feature_ids_list): """Updates ensemble based on the best gains from stats summaries.""" node_ids_per_feature = [] gains_list = [] thresholds_list = [] left_node_contribs_list = [] right_node_contribs_list = [] all_feature_ids = [] assert len(stats_summaries_list) == len(feature_ids_list) for i, feature_ids in enumerate(feature_ids_list): (numeric_node_ids_per_feature, numeric_gains_list, numeric_thresholds_list, numeric_left_node_contribs_list, numeric_right_node_contribs_list) = ( boosted_trees_ops.calculate_best_gains_per_feature( node_id_range=last_layer_nodes_range, stats_summary_list=stats_summaries_list[i], l1=tree_hparams.l1, l2=tree_hparams.l2, tree_complexity=tree_hparams.tree_complexity, min_node_weight=tree_hparams.min_node_weight, max_splits=max_splits)) all_feature_ids += feature_ids node_ids_per_feature += numeric_node_ids_per_feature gains_list += numeric_gains_list thresholds_list += numeric_thresholds_list left_node_contribs_list += numeric_left_node_contribs_list right_node_contribs_list += numeric_right_node_contribs_list grow_op = boosted_trees_ops.update_ensemble( # Confirm if local_tree_ensemble or tree_ensemble should be used. tree_ensemble.resource_handle, feature_ids=all_feature_ids, node_ids=node_ids_per_feature, gains=gains_list, thresholds=thresholds_list, left_node_contribs=left_node_contribs_list, right_node_contribs=right_node_contribs_list, learning_rate=tree_hparams.learning_rate, max_depth=tree_hparams.max_depth, pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING) return grow_op if train_in_memory and is_single_machine: train_op.append(distribute_lib.increment_var(global_step)) train_op.append( grow_tree_from_stats_summaries(stats_summaries_list, feature_ids_list)) else: dependencies = [] for i, feature_ids in enumerate(feature_ids_list): stats_summaries = stats_summaries_list[i] accumulator = data_flow_ops.ConditionalAccumulator( dtype=dtypes.float32, # The stats consist of grads and hessians (the last dimension). shape=[len(feature_ids), max_splits, bucket_size_list[i], 2], shared_name='numeric_stats_summary_accumulator_' + str(i)) accumulators.append(accumulator) apply_grad = accumulator.apply_grad( array_ops.stack(stats_summaries, axis=0), stamp_token) dependencies.append(apply_grad) def grow_tree_from_accumulated_summaries_fn(): """Updates the tree with the best layer from accumulated summaries.""" # Take out the accumulated summaries from the accumulator and grow. 
stats_summaries_list = [ array_ops.unstack(accumulator.take_grad(1), axis=0) for accumulator in accumulators ] grow_op = grow_tree_from_stats_summaries(stats_summaries_list, feature_ids_list) return grow_op with ops.control_dependencies(dependencies): train_op.append(distribute_lib.increment_var(global_step)) if config.is_chief: min_accumulated = math_ops.reduce_min( array_ops.stack( [acc.num_accumulated() for acc in accumulators])) train_op.append( control_flow_ops.cond( math_ops.greater_equal(min_accumulated, n_batches_per_layer), grow_tree_from_accumulated_summaries_fn, control_flow_ops.no_op, name='wait_until_n_batches_accumulated')) return control_flow_ops.group(train_op, name='train_op') estimator_spec = head.create_estimator_spec( features=features, mode=mode, labels=labels, train_op_fn=_train_op_fn, logits=logits) if mode == model_fn.ModeKeys.TRAIN: # Add an early stop hook. estimator_spec = estimator_spec._replace( training_hooks=estimator_spec.training_hooks + (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers, tree_hparams.n_trees, tree_hparams.max_depth),)) return estimator_spec
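# A pure-Python mock (hypothetical class, not the TF op) of the accumulator
# gating above: every worker applies one stats batch per step, and only the
# chief grows a layer, once each accumulator holds at least
# n_batches_per_layer batches. The real data_flow_ops.ConditionalAccumulator
# additionally stamps applies with stamp_token so stale updates are dropped.
class MockStatsAccumulator:

  def __init__(self):
    self._grads = []

  def apply_grad(self, grad):
    self._grads.append(grad)

  def num_accumulated(self):
    return len(self._grads)

  def take_grad(self):
    # Returns the sum of everything applied so far and empties the store.
    total = self._grads[0]
    for grad in self._grads[1:]:
      total = total + grad
    self._grads = []
    return total

def maybe_grow(accumulators, n_batches_per_layer, is_chief, grow_fn):
  min_accumulated = min(acc.num_accumulated() for acc in accumulators)
  if is_chief and min_accumulated >= n_batches_per_layer:
    grow_fn([acc.take_grad() for acc in accumulators])

acc = MockStatsAccumulator()
acc.apply_grad(1.0)
maybe_grow([acc], n_batches_per_layer=2, is_chief=True,
           grow_fn=lambda stats: None)  # not enough batches yet: no-op
acc.apply_grad(2.0)
grown = []
maybe_grow([acc], n_batches_per_layer=2, is_chief=True, grow_fn=grown.append)
assert grown == [[3.0]]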
def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self): """Tests predictions when some nodes in the previous tree were post-pruned.""" with self.cached_session() as session: tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() text_format.Merge(""" trees { nodes { bucketized_split { feature_id: 0 threshold: 33 left_id: 1 right_id: 2 } metadata { gain: -0.2 } } nodes { leaf { scalar: 0.01 } } nodes { bucketized_split { feature_id: 1 threshold: 5 left_id: 3 right_id: 4 } metadata { gain: 0.5 original_leaf { scalar: 0.0143 } } } nodes { leaf { scalar: 0.0553 } } nodes { leaf { scalar: 0.0783 } } } trees { nodes { leaf { scalar: 0.55 } } } tree_weights: 1.0 tree_weights: 1.0 tree_metadata { num_layers_grown: 3 is_finalized: true post_pruned_nodes_meta { new_node_id: 0 logit_change: 0.0 } post_pruned_nodes_meta { new_node_id: 1 logit_change: 0.0 } post_pruned_nodes_meta { new_node_id: 2 logit_change: 0.0 } post_pruned_nodes_meta { new_node_id: 1 logit_change: -0.07 } post_pruned_nodes_meta { new_node_id: 1 logit_change: -0.083 } post_pruned_nodes_meta { new_node_id: 3 logit_change: 0.0 } post_pruned_nodes_meta { new_node_id: 4 logit_change: 0.0 } post_pruned_nodes_meta { new_node_id: 1 logit_change: -0.22 } post_pruned_nodes_meta { new_node_id: 1 logit_change: -0.57 } } tree_metadata { num_layers_grown: 1 is_finalized: false } growing_metadata { num_trees_attempted: 2 num_layers_attempted: 4 } """, tree_ensemble_config) # Create existing ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() cached_tree_ids = [0, 0, 0, 0, 0, 0] # Leaves 3, 4, 7 and 8 were deleted during post-pruning; leaves 5 and 6 # had their ids changed to 3 and 4 respectively. cached_node_ids = [3, 4, 5, 6, 7, 8] # We have two features: 0 and 1. feature_0_values = [12, 17, 35, 36, 23, 11] feature_1_values = [12, 12, 17, 18, 123, 24] # Run training prediction. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values, feature_1_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) # We are in the last tree. self.assertAllClose([1, 1, 1, 1, 1, 1], new_tree_ids) # In tree 0, examples from leaves 3, 4, 7 and 8 are now in leaf 1, and # examples from leaves 5 and 6 are now in leaves 3 and 4. In tree 1, all # of the examples are in the root node. self.assertAllClose([0, 0, 0, 0, 0, 0], new_node_ids) cached_values = [[0.08], [0.093], [0.0553], [0.0783], [0.15 + 0.08], [0.5 + 0.08]] root = 0.55 self.assertAllClose([[root + 0.01], [root + 0.01], [root + 0.0553], [root + 0.0783], [root + 0.01], [root + 0.01]], logits_updates + cached_values)
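# A hedged sketch of how post_pruned_nodes_meta is consumed: entry i maps
# pre-pruning node id i to the node that survived pruning, plus the logit
# correction for predictions cached at the pruned node. The values below are
# hand-copied from the proto above, for illustration only.
post_pruned_meta = [  # old node id -> (new_node_id, logit_change)
    (0, 0.0), (1, 0.0), (2, 0.0), (1, -0.07), (1, -0.083),
    (3, 0.0), (4, 0.0), (1, -0.22), (1, -0.57),
]

def remap_cached_node(old_node_id):
  return post_pruned_meta[old_node_id]

# Example 1 above was cached at old node 3 with logit 0.08; after pruning it
# lands in node 1 and its cached logit is corrected by -0.07 down to 0.01.
new_id, logit_change = remap_cached_node(3)
assert new_id == 1 and abs(0.08 + logit_change - 0.01) < 1e-9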
def testCategoricalSplits(self): """Tests that training prediction works for categorical splits.""" with self.cached_session() as session: tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() text_format.Merge(""" trees { nodes { categorical_split { feature_id: 1 value: 2 left_id: 1 right_id: 2 } } nodes { categorical_split { feature_id: 0 value: 13 left_id: 3 right_id: 4 } } nodes { leaf { scalar: 7.0 } } nodes { leaf { scalar: 5.0 } } nodes { leaf { scalar: 6.0 } } } tree_weights: 1.0 tree_metadata { is_finalized: true } """, tree_ensemble_config) # Create existing ensemble with one root split tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() feature_0_values = [13, 1, 3] feature_1_values = [2, 2, 1] # No previous cached values. cached_tree_ids = [0, 0, 0] cached_node_ids = [0, 0, 0] # Run training prediction. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values, feature_1_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) self.assertAllClose([0, 0, 0], new_tree_ids) self.assertAllClose([3, 4, 2], new_node_ids) self.assertAllClose([[5.], [6.], [7.]], logits_updates)
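# Categorical splits route on equality rather than a threshold: an example
# goes left iff its categorical value equals the split's value. A minimal
# pure-Python traversal sketch reproducing the three assertions above (the
# tree dict is hand-copied from the proto, for illustration only):
def categorical_walk(tree, features, node=0):
  kind, payload = tree[node]
  while kind == 'split':
    feature_id, value, left, right = payload
    node = left if features[feature_id] == value else right
    kind, payload = tree[node]
  return node, payload

tree = {0: ('split', (1, 2, 1, 2)), 1: ('split', (0, 13, 3, 4)),
        2: ('leaf', 7.0), 3: ('leaf', 5.0), 4: ('leaf', 6.0)}
for features, expected in [({0: 13, 1: 2}, (3, 5.0)),
                           ({0: 1, 1: 2}, (4, 6.0)),
                           ({0: 3, 1: 1}, (2, 7.0))]:
  assert categorical_walk(tree, features) == expected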
def testNoCachedPredictionButTreeExists(self): """Tests that predictions are updated once trees are added.""" with self.cached_session() as session: tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() text_format.Merge(""" trees { nodes { bucketized_split { feature_id: 0 threshold: 15 left_id: 1 right_id: 2 } metadata { gain: 7.62 } } nodes { leaf { scalar: 1.14 } } nodes { leaf { scalar: 8.79 } } } tree_weights: 0.1 tree_metadata { is_finalized: true num_layers_grown: 1 } growing_metadata { num_trees_attempted: 1 num_layers_attempted: 2 } """, tree_ensemble_config) # Create existing ensemble with one root split tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() # Two examples, none were cached before. cached_tree_ids = [0, 0] cached_node_ids = [0, 0] feature_0_values = [67, 5] # Run training prediction. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) # We are in the first tree. self.assertAllClose([0, 0], new_tree_ids) self.assertAllClose([2, 1], new_node_ids) self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logits_updates)
def testCachedPredictionTheWholeTreeWasPruned(self): """Tests the predictions when the whole tree was post-pruned.""" with self.cached_session() as session: tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() text_format.Merge(""" trees { nodes { leaf { scalar: 0.00 } } } tree_weights: 1.0 tree_metadata { num_layers_grown: 1 is_finalized: true post_pruned_nodes_meta { new_node_id: 0 logit_change: 0.0 } post_pruned_nodes_meta { new_node_id: 0 logit_change: -6.0 } post_pruned_nodes_meta { new_node_id: 0 logit_change: 5.0 } } growing_metadata { num_trees_attempted: 1 num_layers_attempted: 1 } """, tree_ensemble_config) # Create existing ensemble. tree_ensemble = boosted_trees_ops.TreeEnsemble( 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) tree_ensemble_handle = tree_ensemble.resource_handle resources.initialize_resources(resources.shared_resources()).run() cached_tree_ids = [0, 0] # The predictions were cached in 1 and 2, both were pruned to the root. cached_node_ids = [1, 2] # We have two features: 0 and 1. These are not going to be used anywhere. feature_0_values = [12, 17] feature_1_values = [12, 12] # Run training prediction. predict_op = boosted_trees_ops.training_predict( tree_ensemble_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=[feature_0_values, feature_1_values], logits_dimension=1) logits_updates, new_tree_ids, new_node_ids = session.run(predict_op) # We are in the last tree. self.assertAllClose([0, 0], new_tree_ids) self.assertAllClose([0, 0], new_node_ids) self.assertAllClose([[-6.0], [5.0]], logits_updates)
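# Corner case checked above: the whole tree was pruned back to a single 0.0
# leaf, so the logits update for each cached example is exactly the recorded
# logit_change of its pruned node. Values hand-copied from the proto above.
pruned_tree_meta = [(0, 0.0), (0, -6.0), (0, 5.0)]  # old node -> (new, change)
for cached_node, expected_update in [(1, -6.0), (2, 5.0)]:
  new_node, logit_change = pruned_tree_meta[cached_node]
  assert new_node == 0 and logit_change == expected_update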
def _bt_model_fn( features, labels, mode, head, feature_columns, tree_hparams, n_batches_per_layer, config, closed_form_grad_and_hess_fn=None, example_id_column_name=None, # TODO(youngheek): replace this later using other options. train_in_memory=False, name='TreeEnsembleModel'): """Gradient Boosted Decision Tree model_fn. Args: features: dict of `Tensor`. labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype `int32` or `int64` in the range `[0, n_classes)`. mode: Defines whether this is training, evaluation or prediction. See `ModeKeys`. head: A `head_lib._Head` instance. feature_columns: Iterable of `feature_column._FeatureColumn` model inputs. tree_hparams: TODO. collections.namedtuple for hyper parameters. n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at least n_batches_per_layer accumulations. config: `RunConfig` object to configure the runtime settings. closed_form_grad_and_hess_fn: a function that accepts logits and labels and returns gradients and hessians. By default, they are created by tf.gradients() from the loss. example_id_column_name: Name of the feature for a unique ID per example. Currently experimental -- not exposed to public API. train_in_memory: `bool`, when true, it assumes the dataset is in memory, i.e., input_fn should return the entire dataset as a single batch, and also n_batches_per_layer should be set as 1. name: Name to use for the model. Returns: An `EstimatorSpec` instance. Raises: ValueError: mode or params are invalid, or features has the wrong type. """ is_single_machine = (config.num_worker_replicas == 1) if train_in_memory: assert n_batches_per_layer == 1, ( 'When train_in_memory is enabled, input_fn should return the entire ' 'dataset as a single batch, and n_batches_per_layer should be set as ' '1.') worker_device = control_flow_ops.no_op().device # Maximum number of splits possible in the whole tree = 2^max_depth - 1. # TODO(youngheek): perhaps storage could be optimized by storing stats with # the dimension max_splits_per_layer, instead of max_splits (for the entire # tree). max_splits = (1 << tree_hparams.max_depth) - 1 with ops.name_scope(name) as name: # Prepare. global_step = training_util.get_or_create_global_step() input_feature_list, num_buckets = _get_transformed_features( features, feature_columns) if train_in_memory and mode == model_fn.ModeKeys.TRAIN: input_feature_list = [ _keep_as_local_variable(feature) for feature in input_feature_list ] num_features = len(input_feature_list) cache = None if mode == model_fn.ModeKeys.TRAIN: if train_in_memory and is_single_machine: # maybe just train_in_memory? batch_size = array_ops.shape(input_feature_list[0])[0] cache = _CacheTrainingStatesUsingVariables( batch_size, head.logits_dimension) elif example_id_column_name: example_ids = features[example_id_column_name] cache = _CacheTrainingStatesUsingHashTable( example_ids, head.logits_dimension) # Create Ensemble resources. if is_single_machine: tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) local_tree_ensemble = tree_ensemble ensemble_reload = control_flow_ops.no_op() else: tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) with ops.device(worker_device): local_tree_ensemble = boosted_trees_ops.TreeEnsemble( name=name + '_local', is_local=True) # TODO(soroush): Do partial updates if this becomes a bottleneck. ensemble_reload = local_tree_ensemble.deserialize( *tree_ensemble.serialize()) # Create logits.
if mode != model_fn.ModeKeys.TRAIN: logits = boosted_trees_ops.predict( tree_ensemble_handle=local_tree_ensemble.resource_handle, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension, max_depth=tree_hparams.max_depth) else: if cache: cached_tree_ids, cached_node_ids, cached_logits = cache.lookup( ) else: # Always start from the beginning when no cache is set up. batch_size = array_ops.shape(input_feature_list[0])[0] cached_tree_ids, cached_node_ids, cached_logits = ( array_ops.zeros([batch_size], dtype=dtypes.int32), array_ops.zeros([batch_size], dtype=dtypes.int32), array_ops.zeros([batch_size, head.logits_dimension], dtype=dtypes.float32)) with ops.control_dependencies([ensemble_reload]): (stamp_token, num_trees, num_finalized_trees, num_attempted_layers) = local_tree_ensemble.get_states() summary.scalar('ensemble/num_trees', num_trees) summary.scalar('ensemble/num_finalized_trees', num_finalized_trees) summary.scalar('ensemble/num_attempted_layers', num_attempted_layers) partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict( tree_ensemble_handle=local_tree_ensemble.resource_handle, cached_tree_ids=cached_tree_ids, cached_node_ids=cached_node_ids, bucketized_features=input_feature_list, logits_dimension=head.logits_dimension, max_depth=tree_hparams.max_depth) logits = cached_logits + partial_logits # Create training graph. def _train_op_fn(loss): """Run one training iteration.""" train_op = [] if cache: train_op.append(cache.insert(tree_ids, node_ids, logits)) if closed_form_grad_and_hess_fn: gradients, hessians = closed_form_grad_and_hess_fn( logits, labels) else: gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0] hessians = gradients_impl.gradients(gradients, logits, name='Hessians')[0] stats_summary_list = [ array_ops.squeeze(boosted_trees_ops.make_stats_summary( node_ids=node_ids, gradients=gradients, hessians=hessians, bucketized_features_list=[input_feature_list[f]], max_splits=max_splits, num_buckets=num_buckets), axis=0) for f in range(num_features) ] def grow_tree_from_stats_summaries(stats_summary_list): """Updates ensemble based on the best gains from stats summaries.""" (node_ids_per_feature, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list) = ( boosted_trees_ops.calculate_best_gains_per_feature( node_id_range=array_ops.stack([ math_ops.reduce_min(node_ids), math_ops.reduce_max(node_ids) ]), stats_summary_list=stats_summary_list, l1=tree_hparams.l1, l2=tree_hparams.l2, tree_complexity=tree_hparams.tree_complexity, max_splits=max_splits)) grow_op = boosted_trees_ops.update_ensemble( # Confirm if local_tree_ensemble or tree_ensemble should be used. tree_ensemble.resource_handle, feature_ids=math_ops.range(0, num_features, dtype=dtypes.int32), node_ids=node_ids_per_feature, gains=gains_list, thresholds=thresholds_list, left_node_contribs=left_node_contribs_list, right_node_contribs=right_node_contribs_list, learning_rate=tree_hparams.learning_rate, max_depth=tree_hparams.max_depth, pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING) return grow_op if train_in_memory and is_single_machine: train_op.append(state_ops.assign_add(global_step, 1)) train_op.append( grow_tree_from_stats_summaries(stats_summary_list)) else: summary_accumulator = data_flow_ops.ConditionalAccumulator( dtype=dtypes.float32, # The stats consist of gradients and hessians (the last dimension). 
shape=[num_features, max_splits, num_buckets, 2], shared_name='stats_summary_accumulator') apply_grad = summary_accumulator.apply_grad( array_ops.stack(stats_summary_list, axis=0), stamp_token) def grow_tree_from_accumulated_summaries_fn(): """Updates the tree with the best layer from accumulated summaries.""" # Take out the accumulated summaries from the accumulator and grow. stats_summary_list = array_ops.unstack( summary_accumulator.take_grad(1), axis=0) grow_op = grow_tree_from_stats_summaries( stats_summary_list) return grow_op with ops.control_dependencies([apply_grad]): train_op.append(state_ops.assign_add(global_step, 1)) if config.is_chief: train_op.append( control_flow_ops.cond( math_ops.greater_equal( summary_accumulator.num_accumulated(), n_batches_per_layer), grow_tree_from_accumulated_summaries_fn, control_flow_ops.no_op, name='wait_until_n_batches_accumulated')) return control_flow_ops.group(train_op, name='train_op') estimator_spec = head.create_estimator_spec(features=features, mode=mode, labels=labels, train_op_fn=_train_op_fn, logits=logits) if mode == model_fn.ModeKeys.TRAIN: # Add an early stop hook. estimator_spec = estimator_spec._replace( training_hooks=estimator_spec.training_hooks + (StopAtNumTreesHook(num_trees, tree_hparams.n_trees), )) return estimator_spec
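# Worked check of the max_splits bound used in the model_fn variants above: a
# tree with max_depth layers of internal nodes has
# 2**0 + 2**1 + ... + 2**(max_depth - 1) = 2**max_depth - 1 splits, which is
# what (1 << tree_hparams.max_depth) - 1 computes.
for max_depth in range(1, 8):
  assert sum(2 ** level for level in range(max_depth)) == (1 << max_depth) - 1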