def testCreate(self):
  """Creates a one-leaf ensemble and checks its predictions and stamp token.

  The ensemble resource is created with stamp token 3 from a config holding a
  single tree (weight 1.0) whose only node is a leaf filled in through
  `_append_to_leaf(leaf, 0, -0.4)`. The test then asserts that the prediction
  op evaluates to [[-0.4], [-0.4]] and that the stamp token read back from the
  resource is 3.
  """
  with self.test_session():
    # Ensemble config: one tree, one leaf node, tree weight 1.0.
    ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
    _append_to_leaf(ensemble_proto.trees.add().nodes.add().leaf, 0, -0.4)
    ensemble_proto.tree_weights.append(1.0)

    # Prepare learner config.
    learner_proto = learner_pb2.LearnerConfig()
    learner_proto.num_classes = 2

    handle = model_ops.tree_ensemble_variable(
        stamp_token=3,
        tree_ensemble_config=ensemble_proto.SerializeToString(),
        name="create_tree")
    # The resource must be initialized before any op reads from it.
    resources.initialize_resources(resources.shared_resources()).run()

    # Only the first of the three outputs is checked by this test.
    result, unused_second, unused_third = (
        prediction_ops.gradient_trees_prediction(
            handle,
            self._seed,
            [self._dense_float_tensor],
            [self._sparse_float_indices1, self._sparse_float_indices2],
            [self._sparse_float_values1, self._sparse_float_values2],
            [self._sparse_float_shape1, self._sparse_float_shape2],
            [self._sparse_int_indices1],
            [self._sparse_int_values1],
            [self._sparse_int_shape1],
            learner_config=learner_proto.SerializeToString(),
            apply_dropout=False,
            apply_averaging=False,
            center_bias=False,
            reduce_dim=True))
    self.assertAllClose(result.eval(), [[-0.4], [-0.4]])

    # The stamp token given at creation time must be readable back unchanged.
    stamp_token = model_ops.tree_ensemble_stamp_token(handle)
    self.assertEqual(stamp_token.eval(), 3)
def testCreate(self):
  """Verifies predictions and stamp token of a newly created ensemble.

  Builds a config with a single tree (weight 1.0) whose only node is a leaf
  populated through `_append_to_leaf(leaf, 0, -0.4)`, creates the ensemble
  resource with stamp token 3, and asserts that the prediction op evaluates
  to [[-0.4], [-0.4]] and that the stamp token read back is 3.
  """
  with self.cached_session():
    config_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
    only_tree = config_proto.trees.add()
    only_leaf = only_tree.nodes.add().leaf
    _append_to_leaf(only_leaf, 0, -0.4)
    config_proto.tree_weights.append(1.0)

    # Prepare learner config.
    learner_proto = learner_pb2.LearnerConfig()
    learner_proto.num_classes = 2

    ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=3,
        tree_ensemble_config=config_proto.SerializeToString(),
        name="create_tree")
    # Initialize the shared resources before running anything that reads them.
    resources.initialize_resources(resources.shared_resources()).run()

    # Feature inputs, grouped by kind, taken from the test fixture attributes.
    dense_floats = [self._dense_float_tensor]
    sparse_float_indices = [
        self._sparse_float_indices1, self._sparse_float_indices2
    ]
    sparse_float_values = [
        self._sparse_float_values1, self._sparse_float_values2
    ]
    sparse_float_shapes = [
        self._sparse_float_shape1, self._sparse_float_shape2
    ]
    sparse_int_indices = [self._sparse_int_indices1]
    sparse_int_values = [self._sparse_int_values1]
    sparse_int_shapes = [self._sparse_int_shape1]

    # Only the first of the two outputs is checked by this test.
    result, unused_second = prediction_ops.gradient_trees_prediction(
        ensemble_handle,
        self._seed,
        dense_floats,
        sparse_float_indices,
        sparse_float_values,
        sparse_float_shapes,
        sparse_int_indices,
        sparse_int_values,
        sparse_int_shapes,
        learner_config=learner_proto.SerializeToString(),
        apply_dropout=False,
        apply_averaging=False,
        center_bias=False,
        reduce_dim=True)
    self.assertAllClose(result.eval(), [[-0.4], [-0.4]])

    stamp_token = model_ops.tree_ensemble_stamp_token(ensemble_handle)
    self.assertEqual(stamp_token.eval(), 3)
def predict(self, mode):
  """Returns predictions given the features and mode.

  If the ensemble handle lives on the same device as the inputs it is used
  directly. Otherwise a local ensemble resource is created, refreshed from
  the serialized source ensemble whenever the two stamp tokens differ, and
  used for prediction instead.

  Args:
    mode: Mode the graph is running in (train|predict|eval).

  Returns:
    A dict of predictions tensors.

  Raises:
    ValueError: if there are no input tensors, or if the input tensors are
      not all on the same device.
  """
  # Use the current ensemble to predict on the current batch of input.
  feature_tensors = (
      self._dense_floats + self._sparse_float_indices +
      self._sparse_int_indices)
  if not feature_tensors:
    raise ValueError("No input tensors for prediction.")
  input_device = feature_tensors[0].device
  if any(tensor.device != input_device for tensor in feature_tensors):
    raise ValueError("All input tensors should be on the same device.")

  # Get most current model stamp.
  ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)

  # Use ensemble_handle directly, if colocated with the inputs.
  if self._ensemble_handle.device == input_device:
    with ops.device(self._ensemble_handle.device):
      return self._predict_and_return_dict(self._ensemble_handle,
                                           ensemble_stamp, mode)

  # Not colocated: create a local ensemble and get its local stamp.
  with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as scope:
    local_handle = gen_model_ops.decision_tree_ensemble_resource_handle_op(
        name=scope)
    create_local_op = gen_model_ops.create_tree_ensemble_variable(
        local_handle, stamp_token=-1, tree_ensemble_config="")
    with ops.control_dependencies([create_local_op]):
      local_stamp = model_ops.tree_ensemble_stamp_token(local_handle)

  def _refresh_local_ensemble_fn():
    """Copies the serialized source ensemble into the local resource."""
    # Serialize the model from parameter server after reading all inputs.
    with ops.control_dependencies(feature_tensors):
      fresh_stamp, serialized_model = model_ops.tree_ensemble_serialize(
          self._ensemble_handle)

    # Update local ensemble with the serialized model from parameter server.
    with ops.control_dependencies([create_local_op]):
      deserialize_op = model_ops.tree_ensemble_deserialize(
          local_handle,
          stamp_token=fresh_stamp,
          tree_ensemble_config=serialized_model)
      return deserialize_op, fresh_stamp

  # The local ensemble is stale exactly when the two stamps differ; only then
  # is it refreshed. Either way the cond yields the stamp to predict with.
  is_stale = math_ops.not_equal(ensemble_stamp, local_stamp)
  refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond(
      is_stale,
      _refresh_local_ensemble_fn,
      lambda: (control_flow_ops.no_op(), ensemble_stamp))

  # Once updated, use the local model for prediction.
  with ops.control_dependencies([refresh_local_ensemble]):
    return self._predict_and_return_dict(local_handle, ensemble_stamp, mode)
def predict(self, mode):
  """Returns predictions given the features and mode.

  If the ensemble handle is on the same device as the inputs it is used
  directly. Otherwise a local ensemble resource is created, refreshed from
  the serialized source ensemble whenever the stamp tokens differ, and the
  prediction ops are built against that local copy.

  Args:
    mode: Mode the graph is running in (train|predict|eval).

  Returns:
    A dict of predictions tensors.

  Raises:
    ValueError: if there are no input tensors, or if the input tensors are
      not all on the same device.
  """
  # Use the current ensemble to predict on the current batch of input.
  # For faster prediction we check if the inputs are on the same device
  # as the model. If not, we create a copy of the model on the worker.
  input_deps = (self._dense_floats + self._sparse_float_indices +
                self._sparse_int_indices)
  if not input_deps:
    raise ValueError("No input tensors for prediction.")
  if any(i.device != input_deps[0].device for i in input_deps):
    raise ValueError("All input tensors should be on the same device.")

  # Averaging is applied in every mode except training.
  apply_averaging = mode != learn.ModeKeys.TRAIN

  def _build_prediction_ops(ensemble_handle, stamp):
    """Builds stats, prediction and partition ops against one ensemble."""
    ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, stamp)
    apply_dropout, seed = _dropout_params(mode, ensemble_stats)
    # We don't need dropout info - we can always restore it based on the
    # seed.
    predictions, predictions_no_dropout, _ = (
        prediction_ops.gradient_trees_prediction(
            ensemble_handle,
            seed,
            self._dense_floats,
            self._sparse_float_indices,
            self._sparse_float_values,
            self._sparse_float_shapes,
            self._sparse_int_indices,
            self._sparse_int_values,
            self._sparse_int_shapes,
            learner_config=self._learner_config_serialized,
            apply_dropout=apply_dropout,
            apply_averaging=apply_averaging,
            use_locking=False,
            center_bias=self._center_bias,
            reduce_dim=self._reduce_dim))
    partition_ids = prediction_ops.gradient_trees_partition_examples(
        ensemble_handle,
        self._dense_floats,
        self._sparse_float_indices,
        self._sparse_float_values,
        self._sparse_float_shapes,
        self._sparse_int_indices,
        self._sparse_int_values,
        self._sparse_int_shapes,
        use_locking=False)
    return predictions, predictions_no_dropout, partition_ids, ensemble_stats

  # Get most current model stamp.
  ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)

  # Determine if ensemble is colocated with the inputs.
  if self._ensemble_handle.device != input_deps[0].device:
    # Create a local ensemble and get its local stamp.
    with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name:
      local_ensemble_handle = (
          gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name))
      create_op = gen_model_ops.create_tree_ensemble_variable(
          local_ensemble_handle, stamp_token=-1, tree_ensemble_config="")
      with ops.control_dependencies([create_op]):
        local_stamp = model_ops.tree_ensemble_stamp_token(
            local_ensemble_handle)

    # Determine whether the local ensemble is stale and update it if needed.
    def _refresh_local_ensemble_fn():
      # Serialize the model from parameter server after reading all inputs.
      with ops.control_dependencies(input_deps):
        (fresh_stamp, serialized_model) = (
            model_ops.tree_ensemble_serialize(self._ensemble_handle))

      # Update local ensemble with the serialized model from parameter
      # server.
      with ops.control_dependencies([create_op]):
        return model_ops.tree_ensemble_deserialize(
            local_ensemble_handle,
            stamp_token=fresh_stamp,
            tree_ensemble_config=serialized_model), fresh_stamp

    refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond(
        math_ops.not_equal(ensemble_stamp, local_stamp),
        _refresh_local_ensemble_fn,
        lambda: (control_flow_ops.no_op(), ensemble_stamp))

    # Once updated, use the local model for prediction.
    with ops.control_dependencies([refresh_local_ensemble]):
      (predictions, predictions_no_dropout, partition_ids,
       ensemble_stats) = _build_prediction_ops(local_ensemble_handle,
                                               ensemble_stamp)
  else:
    # Use ensemble_handle directly, if colocated.
    with ops.device(self._ensemble_handle.device):
      (predictions, predictions_no_dropout, partition_ids,
       ensemble_stats) = _build_prediction_ops(self._ensemble_handle,
                                               ensemble_stamp)

  # Built outside the control-dependency / device scopes, as before.
  return _make_predictions_dict(ensemble_stamp, predictions,
                                predictions_no_dropout, partition_ids,
                                ensemble_stats)
def predict(self, mode):
  """Creates the prediction ops for the current ensemble and input batch.

  The ensemble handle is used as-is when it shares a device with the inputs.
  When it does not, a local ensemble resource is created and, if its stamp
  token differs from the source ensemble's, overwritten with the serialized
  source ensemble before predicting from it.

  Args:
    mode: Mode the graph is running in (train|predict|eval).

  Returns:
    A dict of predictions tensors.

  Raises:
    ValueError: if there are no input tensors, or if the input tensors are
      not all on the same device.
  """
  # Every input tensor the prediction reads; all must share one device.
  batch_inputs = (self._dense_floats + self._sparse_float_indices +
                  self._sparse_int_indices)
  if not batch_inputs:
    raise ValueError("No input tensors for prediction.")
  for tensor in batch_inputs:
    if tensor.device != batch_inputs[0].device:
      raise ValueError("All input tensors should be on the same device.")

  # Get most current model stamp.
  ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)

  # Determine if ensemble is colocated with the inputs.
  needs_local_copy = self._ensemble_handle.device != batch_inputs[0].device
  if needs_local_copy:
    # Create a local ensemble and get its local stamp.
    with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as scope:
      worker_handle = (
          gen_model_ops.decision_tree_ensemble_resource_handle_op(name=scope))
      create_worker_copy = gen_model_ops.create_tree_ensemble_variable(
          worker_handle, stamp_token=-1, tree_ensemble_config="")
      with ops.control_dependencies([create_worker_copy]):
        worker_stamp = model_ops.tree_ensemble_stamp_token(worker_handle)

    def _refresh_local_ensemble_fn():
      """Overwrites the local ensemble with the serialized source ensemble."""
      # Serialize the model from parameter server after reading all inputs.
      with ops.control_dependencies(batch_inputs):
        (source_stamp, source_model) = (
            model_ops.tree_ensemble_serialize(self._ensemble_handle))

      # Update local ensemble with the serialized model from parameter
      # server.
      with ops.control_dependencies([create_worker_copy]):
        return model_ops.tree_ensemble_deserialize(
            worker_handle,
            stamp_token=source_stamp,
            tree_ensemble_config=source_model), source_stamp

    def _keep_local_ensemble_fn():
      """Leaves the local ensemble untouched and keeps the current stamp."""
      return control_flow_ops.no_op(), ensemble_stamp

    # Determine whether the local ensemble is stale and update it if needed.
    refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond(
        math_ops.not_equal(ensemble_stamp, worker_stamp),
        _refresh_local_ensemble_fn,
        _keep_local_ensemble_fn)

    # Once updated, use the local model for prediction.
    with ops.control_dependencies([refresh_local_ensemble]):
      return self._predict_and_return_dict(worker_handle, ensemble_stamp,
                                           mode)
  else:
    # Use ensemble_handle directly, if colocated.
    with ops.device(self._ensemble_handle.device):
      return self._predict_and_return_dict(self._ensemble_handle,
                                           ensemble_stamp, mode)