def __init__(self, resource_handle, create_op, name):
  """Creates a _TreeEnsembleSavable object.

  The ensemble behind `resource_handle` is serialized into a stamp token and a
  serialized proto. These two tensors are registered as save specs named
  `name + '_stamp'` and `name + '_serialized'`.

  Args:
    resource_handle: handle to the decision tree ensemble variable.
    create_op: the op to initialize the variable.
    name: the name to save the tree ensemble variable under.
  """
  stamp, proto = gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
      resource_handle)
  # A slice_spec selects part of a variable to save. Slicing is not
  # meaningful for the tree ensemble variable, so an empty spec is passed.
  whole_variable = ''
  make_spec = saver.BaseSaverBuilder.SaveSpec
  specs = [
      make_spec(tensor, whole_variable, name + suffix)
      for tensor, suffix in ((stamp, '_stamp'), (proto, '_serialized'))
  ]
  super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
  self._resource_handle = resource_handle
  self._create_op = create_op
def _get_train_op_and_ensemble(self, head, config, is_classification,
                               train_in_memory):
  """Calls bt_model_fn() and returns the train_op and ensemble_serialized.

  Builds the TRAIN-mode model on the test input and initializes resources,
  global variables and local variables (in that order).

  Args:
    head: head passed to bt_model_fn().
    config: config passed to bt_model_fn().
    is_classification: forwarded to _make_train_input_fn() to choose the
      training input.
    train_in_memory: forwarded to bt_model_fn().

  Returns:
    A (train_op, ensemble_serialized) tuple. ensemble_serialized is created
    under a control dependency on train_op, so evaluating it runs the
    training step first.
  """
  input_fn = _make_train_input_fn(is_classification)
  features, labels = input_fn()
  spec = boosted_trees._bt_model_fn(  # pylint:disable=protected-access
      features=features,
      labels=labels,
      mode=model_fn.ModeKeys.TRAIN,
      head=head,
      feature_columns=self._feature_columns,
      tree_hparams=self._tree_hparams,
      example_id_column_name=EXAMPLE_ID_COLUMN,
      n_batches_per_layer=1,
      config=config,
      train_in_memory=train_in_memory)

  resources.initialize_resources(resources.shared_resources()).run()
  variables.global_variables_initializer().run()
  variables.local_variables_initializer().run()

  # Exactly one shared resource is expected; its handle is serialized below.
  registered = resources.shared_resources()
  self.assertEqual(1, len(registered))

  train_op = spec.train_op
  with ops.control_dependencies([train_op]):
    _stamp_token, ensemble_serialized = (
        gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
            registered[0].handle))
  return train_op, ensemble_serialized
def _get_train_op_and_ensemble(self, head, config, is_classification,
                               train_in_memory):
  """Calls bt_model_fn() and returns the train_op and ensemble_serialized.

  Args:
    head: head passed to bt_model_fn().
    config: config passed to bt_model_fn().
    is_classification: selects the training input via _make_train_input_fn().
    train_in_memory: forwarded to bt_model_fn().

  Returns:
    train_op: the training op from the returned estimator spec.
    ensemble_serialized: serialized-ensemble tensor. It is created under a
      control dependency on train_op, so it is evaluated after training.
  """
  features, labels = _make_train_input_fn(is_classification)()
  model_fn_kwargs = dict(
      features=features,
      labels=labels,
      mode=model_fn.ModeKeys.TRAIN,
      head=head,
      feature_columns=self._feature_columns,
      tree_hparams=self._tree_hparams,
      example_id_column_name=EXAMPLE_ID_COLUMN,
      n_batches_per_layer=1,
      config=config,
      train_in_memory=train_in_memory)
  spec = boosted_trees._bt_model_fn(**model_fn_kwargs)  # pylint:disable=protected-access

  # Run the initializers one at a time: resources, then globals, then locals.
  resources.initialize_resources(resources.shared_resources()).run()
  variables.global_variables_initializer().run()
  variables.local_variables_initializer().run()

  shared = resources.shared_resources()
  self.assertEqual(1, len(shared))
  ensemble_handle = shared[0].handle

  train_op = spec.train_op
  with ops.control_dependencies([train_op]):
    outputs = gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        ensemble_handle)
    _, ensemble_serialized = outputs
  return train_op, ensemble_serialized
def serialize(self):
  """Returns the ensemble behind `self.resource_handle` as a serialized proto.

  Returns:
    stamp_token: int64 scalar Tensor to denote the stamp of the resource.
    serialized_proto: string scalar Tensor of the serialized proto.
  """
  handle = self.resource_handle
  return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(handle)
def serialize(self):
  """Serializes the ensemble into a proto via the serialize-ensemble op.

  Returns:
    stamp_token: int64 scalar Tensor to denote the stamp of the resource.
    serialized_proto: string scalar Tensor of the serialized proto.
  """
  serialize_op = gen_boosted_trees_ops.boosted_trees_serialize_ensemble
  return serialize_op(self.resource_handle)
def __init__(self, resource_handle, create_op, name):
  """Creates a _TreeEnsembleSavable object.

  Two save specs are registered for the ensemble: its stamp token (saved as
  `name + '_stamp'`) and its serialized proto (saved as
  `name + '_serialized'`).

  Args:
    resource_handle: handle to the decision tree ensemble variable.
    create_op: the op to initialize the variable.
    name: the name to save the tree ensemble variable under.
  """
  serialize_outputs = gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
      resource_handle)
  stamp_token, serialized = serialize_outputs

  # Slice specs exist for saving part of a variable. That is not meaningful
  # for the tree ensemble variable, so an empty slice spec is used.
  no_slice = ''
  stamp_spec = saver.BaseSaverBuilder.SaveSpec(
      stamp_token, no_slice, name + '_stamp')
  proto_spec = saver.BaseSaverBuilder.SaveSpec(
      serialized, no_slice, name + '_serialized')

  super(_TreeEnsembleSavable, self).__init__(
      resource_handle, [stamp_spec, proto_spec], name)
  self._resource_handle = resource_handle
  self._create_op = create_op