def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
  """Runs prediction and returns a dictionary of the prediction results.

  Args:
    ensemble_handle: ensemble resource handle.
    ensemble_stamp: stamp of ensemble resource.
    mode: learn.ModeKeys.TRAIN or EVAL or INFER.

  Returns:
    A dictionary of prediction results -
      ENSEMBLE_STAMP, PREDICTIONS, PARTITION_IDS,
      NUM_LAYERS_ATTEMPTED, NUM_TREES_ATTEMPTED.
  """
  ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle,
                                                    ensemble_stamp)
  num_handlers = (
      len(self._dense_floats) + len(self._sparse_float_shapes) +
      len(self._sparse_int_shapes))
  # Used during feature selection.
  used_handlers = model_ops.tree_ensemble_used_handlers(
      ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers)

  # We don't need dropout info - we can always restore it based on the
  # seed.
  apply_dropout, seed = _dropout_params(mode, ensemble_stats)
  # Make sure ensemble stats run. This will check that the ensemble has
  # the right stamp.
  with ops.control_dependencies(ensemble_stats):
    predictions, _ = prediction_ops.gradient_trees_prediction(
        ensemble_handle,
        seed,
        self._dense_floats,
        self._sparse_float_indices,
        self._sparse_float_values,
        self._sparse_float_shapes,
        self._sparse_int_indices,
        self._sparse_int_values,
        self._sparse_int_shapes,
        learner_config=self._learner_config_serialized,
        apply_dropout=apply_dropout,
        apply_averaging=mode != learn.ModeKeys.TRAIN,
        use_locking=True,
        center_bias=self._center_bias,
        reduce_dim=self._reduce_dim)
    partition_ids = prediction_ops.gradient_trees_partition_examples(
        ensemble_handle,
        self._dense_floats,
        self._sparse_float_indices,
        self._sparse_float_values,
        self._sparse_float_shapes,
        self._sparse_int_indices,
        self._sparse_int_values,
        self._sparse_int_shapes,
        use_locking=True)
  return _make_predictions_dict(ensemble_stamp, predictions, partition_ids,
                                ensemble_stats, used_handlers)
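# A minimal usage sketch for the dictionary returned above. The caller
# binding below is hypothetical; the key constants are the ones named in
# the docstring:
#
#   predictions_dict = self._predict_and_return_dict(
#       self._ensemble_handle, ensemble_stamp, learn.ModeKeys.EVAL)
#   predictions = predictions_dict[PREDICTIONS]
#   partition_ids = predictions_dict[PARTITION_IDS]
#
# Since apply_averaging is mode != learn.ModeKeys.TRAIN, EVAL and INFER
# read from the averaged ensemble (when averaging is configured), while
# TRAIN sees the raw ensemble predictions.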
def _update_ensemble():
  """A method to update the tree ensemble."""
  # Get next stamp token.
  next_ensemble_stamp = ensemble_stamp + 1
  # Finalize bias stats.
  _, _, _, bias_grads, bias_hess = bias_stats_accumulator.flush(
      ensemble_stamp, next_ensemble_stamp)

  # Finalize handler splits.
  are_splits_ready_list = []
  partition_ids_list = []
  gains_list = []
  split_info_list = []
  for handler in handlers:
    (are_splits_ready, partition_ids, gains,
     split_info) = handler.make_splits(ensemble_stamp, next_ensemble_stamp,
                                       class_id)
    are_splits_ready_list.append(are_splits_ready)
    partition_ids_list.append(partition_ids)
    gains_list.append(gains)
    split_info_list.append(split_info)

  # Stack all the inputs to one tensor per type.
  # This is a workaround for the slowness of graph building in tf.cond.
  # See (b/36554864).
  split_sizes = array_ops.stack([
      array_ops.shape(partition_id)[0]
      for partition_id in partition_ids_list
  ])
  partition_ids = array_ops.concat(partition_ids_list, axis=0)
  gains = array_ops.concat(gains_list, axis=0)
  split_infos = array_ops.concat(split_info_list, axis=0)

  # Determine if all splits are ready.
  are_all_splits_ready = math_ops.reduce_all(
      array_ops.stack(
          are_splits_ready_list, axis=0, name="stack_handler_readiness"))

  # Define bias centering update operation.
  def _center_bias_fn():
    # Center tree ensemble bias.
    delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess,
                                    array_ops.zeros_like(bias_grads))
    center_bias = training_ops.center_tree_ensemble_bias(
        tree_ensemble_handle=self._ensemble_handle,
        stamp_token=ensemble_stamp,
        next_stamp_token=next_ensemble_stamp,
        delta_updates=delta_updates,
        learner_config=self._learner_config_serialized)
    return continue_centering.assign(center_bias)

  # Define ensemble growing operations.
  def _grow_ensemble_ready_fn():
    # Grow the ensemble given the current candidates.
    sizes = array_ops.unstack(split_sizes)
    partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
    gains_list = list(array_ops.split(gains, sizes, axis=0))
    split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
    return training_ops.grow_tree_ensemble(
        tree_ensemble_handle=self._ensemble_handle,
        stamp_token=ensemble_stamp,
        next_stamp_token=next_ensemble_stamp,
        learning_rate=learning_rate,
        partition_ids=partition_ids_list,
        gains=gains_list,
        splits=split_info_list,
        learner_config=self._learner_config_serialized,
        dropout_seed=dropout_seed,
        center_bias=self._center_bias)

  def _grow_ensemble_not_ready_fn():
    # Don't grow the ensemble, just update the stamp.
    return training_ops.grow_tree_ensemble(
        tree_ensemble_handle=self._ensemble_handle,
        stamp_token=ensemble_stamp,
        next_stamp_token=next_ensemble_stamp,
        learning_rate=0,
        partition_ids=[],
        gains=[],
        splits=[],
        learner_config=self._learner_config_serialized,
        dropout_seed=dropout_seed,
        center_bias=self._center_bias)

  def _grow_ensemble_fn():
    # Conditionally grow an ensemble depending on whether the splits
    # from all the handlers are ready.
    return control_flow_ops.cond(are_all_splits_ready,
                                 _grow_ensemble_ready_fn,
                                 _grow_ensemble_not_ready_fn)

  # Update ensemble.
  update_ops = [are_all_splits_ready]
  update_model = control_flow_ops.cond(continue_centering, _center_bias_fn,
                                       _grow_ensemble_fn)
  update_ops.append(update_model)

  # Update ensemble stats.
  with ops.control_dependencies([update_model]):
    stats = training_ops.tree_ensemble_stats(
        self._ensemble_handle, stamp_token=next_ensemble_stamp)
    update_ops.append(self._finalized_trees.assign(stats.num_trees))
    update_ops.append(self._attempted_trees.assign(stats.attempted_trees))
    update_ops.append(num_layers.assign(stats.num_layers))
    update_ops.append(active_tree.assign(stats.active_tree))
    update_ops.append(active_layer.assign(stats.active_layer))

  # Flush step stats.
  update_ops.extend(
      steps_accumulator.flush(ensemble_stamp, next_ensemble_stamp))
  return control_flow_ops.group(*update_ops, name="update_ensemble")
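# The stack/concat before the cond above is a general workaround worth
# noting: tf.cond builds every op in each branch function, so packing the
# variable-length handler outputs into single tensors outside the cond and
# splitting them back inside the branch keeps graph construction fast. A
# minimal sketch of the pattern (tensor and function names here are
# hypothetical):
#
#   sizes = array_ops.stack(
#       [array_ops.shape(t)[0] for t in tensor_list])
#   packed = array_ops.concat(tensor_list, axis=0)
#
#   def _ready_fn():
#     # Unpack only inside the branch that consumes the list.
#     unpacked = array_ops.split(packed, array_ops.unstack(sizes), axis=0)
#     return consume(unpacked)  # hypothetical consumer op
#
#   result = control_flow_ops.cond(predicate, _ready_fn, _not_ready_fn)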
def predict(self, mode):
  """Returns predictions given the features and mode.

  Args:
    mode: Mode the graph is running in (train|predict|eval).

  Returns:
    A dict of predictions tensors.

  Raises:
    ValueError: if there are no input tensors, or if they are not all
      placed on the same device.
  """
  apply_averaging = mode != learn.ModeKeys.TRAIN

  # Use the current ensemble to predict on the current batch of input.
  # For faster prediction we check if the inputs are on the same device
  # as the model. If not, we create a copy of the model on the worker.
  input_deps = (
      self._dense_floats + self._sparse_float_indices +
      self._sparse_int_indices)
  if not input_deps:
    raise ValueError("No input tensors for prediction.")

  if any(i.device != input_deps[0].device for i in input_deps):
    raise ValueError("All input tensors should be on the same device.")

  # Get most current model stamp.
  ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)

  # Determine if ensemble is colocated with the inputs.
  if self._ensemble_handle.device != input_deps[0].device:
    # Create a local ensemble and get its local stamp.
    with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name:
      local_ensemble_handle = (
          gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name))
      create_op = gen_model_ops.create_tree_ensemble_variable(
          local_ensemble_handle, stamp_token=-1, tree_ensemble_config="")
      with ops.control_dependencies([create_op]):
        local_stamp = model_ops.tree_ensemble_stamp_token(
            local_ensemble_handle)

    # Determine whether the local ensemble is stale and update it if needed.
    def _refresh_local_ensemble_fn():
      # Serialize the model from parameter server after reading all inputs.
      with ops.control_dependencies(input_deps):
        (ensemble_stamp, serialized_model) = (
            model_ops.tree_ensemble_serialize(self._ensemble_handle))

      # Update local ensemble with the serialized model from parameter
      # server.
      with ops.control_dependencies([create_op]):
        return model_ops.tree_ensemble_deserialize(
            local_ensemble_handle,
            stamp_token=ensemble_stamp,
            tree_ensemble_config=serialized_model), ensemble_stamp

    refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond(
        math_ops.not_equal(ensemble_stamp, local_stamp),
        _refresh_local_ensemble_fn,
        lambda: (control_flow_ops.no_op(), ensemble_stamp))

    # Once updated, use the local model for prediction.
    with ops.control_dependencies([refresh_local_ensemble]):
      ensemble_stats = training_ops.tree_ensemble_stats(
          local_ensemble_handle, ensemble_stamp)
      apply_dropout, seed = _dropout_params(mode, ensemble_stats)
      # We don't need dropout info - we can always restore it based on the
      # seed.
      predictions, predictions_no_dropout, _ = (
          prediction_ops.gradient_trees_prediction(
              local_ensemble_handle,
              seed,
              self._dense_floats,
              self._sparse_float_indices,
              self._sparse_float_values,
              self._sparse_float_shapes,
              self._sparse_int_indices,
              self._sparse_int_values,
              self._sparse_int_shapes,
              learner_config=self._learner_config_serialized,
              apply_dropout=apply_dropout,
              apply_averaging=apply_averaging,
              use_locking=False,
              center_bias=self._center_bias,
              reduce_dim=self._reduce_dim))
      partition_ids = prediction_ops.gradient_trees_partition_examples(
          local_ensemble_handle,
          self._dense_floats,
          self._sparse_float_indices,
          self._sparse_float_values,
          self._sparse_float_shapes,
          self._sparse_int_indices,
          self._sparse_int_values,
          self._sparse_int_shapes,
          use_locking=False)
  else:
    with ops.device(self._ensemble_handle.device):
      ensemble_stats = training_ops.tree_ensemble_stats(
          self._ensemble_handle, ensemble_stamp)
      apply_dropout, seed = _dropout_params(mode, ensemble_stats)
      # We don't need dropout info - we can always restore it based on the
      # seed.
      predictions, predictions_no_dropout, _ = (
          prediction_ops.gradient_trees_prediction(
              self._ensemble_handle,
              seed,
              self._dense_floats,
              self._sparse_float_indices,
              self._sparse_float_values,
              self._sparse_float_shapes,
              self._sparse_int_indices,
              self._sparse_int_values,
              self._sparse_int_shapes,
              learner_config=self._learner_config_serialized,
              apply_dropout=apply_dropout,
              apply_averaging=apply_averaging,
              use_locking=False,
              center_bias=self._center_bias,
              reduce_dim=self._reduce_dim))
      partition_ids = prediction_ops.gradient_trees_partition_examples(
          self._ensemble_handle,
          self._dense_floats,
          self._sparse_float_indices,
          self._sparse_float_values,
          self._sparse_float_shapes,
          self._sparse_int_indices,
          self._sparse_int_values,
          self._sparse_int_shapes,
          use_locking=False)

  return _make_predictions_dict(ensemble_stamp, predictions,
                                predictions_no_dropout, partition_ids,
                                ensemble_stats)
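# The staleness check in predict() above is a small "cached replica"
# pattern: compare a cheap stamp token first, and pay for a full
# serialize/deserialize round trip only when the worker-local copy is out
# of date. Stripped to its core (handle names here are hypothetical):
#
#   def _refresh_fn():
#     stamp, config = model_ops.tree_ensemble_serialize(remote_handle)
#     return model_ops.tree_ensemble_deserialize(
#         local_handle, stamp_token=stamp,
#         tree_ensemble_config=config), stamp
#
#   refresh_op, stamp = control_flow_ops.cond(
#       math_ops.not_equal(remote_stamp, local_stamp),
#       _refresh_fn,
#       lambda: (control_flow_ops.no_op(), remote_stamp))
#
# Both branches of the cond must return matching structures, which is why
# the up-to-date branch still pairs a no_op with the unchanged stamp.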