Beispiel #1
0
    def __init__(self, resource_handle, create_op, name):
        """Creates a _TreeEnsembleSavable object.

        The ensemble behind `resource_handle` is serialized into a stamp
        token and a serialized proto. Each of those tensors is registered as
        its own save spec, named `name + '_stamp'` and `name + '_serialized'`
        respectively.

        Args:
          resource_handle: handle to the decision tree ensemble variable.
          create_op: the op to initialize the variable.
          name: the name to save the tree ensemble variable under.
        """
        stamp_token, serialized = (
            gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
                resource_handle))
        # A slice spec selects part of a variable to save. That is not
        # meaningful for the tree ensemble variable, so an empty value is
        # passed.
        no_slice = ''
        save_spec = saver.BaseSaverBuilder.SaveSpec
        spec_list = [
            save_spec(stamp_token, no_slice, name + '_stamp'),
            save_spec(serialized, no_slice, name + '_serialized'),
        ]
        super(_TreeEnsembleSavable, self).__init__(
            resource_handle, spec_list, name)
        self._resource_handle = resource_handle
        self._create_op = create_op
  def _get_train_op_and_ensemble(self, head, config, is_classification,
                                 train_in_memory):
    """Calls bt_model_fn() and returns the train_op and ensemble_serialized.

    Builds the TRAIN-mode model from the training input fn. It then runs the
    initializers for shared resources and for global and local variables,
    and asserts that exactly one shared resource exists.

    Args:
      head: head forwarded to `_bt_model_fn`.
      config: config forwarded to `_bt_model_fn`.
      is_classification: forwarded to `_make_train_input_fn`.
      train_in_memory: forwarded to `_bt_model_fn`.

    Returns:
      A `(train_op, ensemble_serialized)` tuple. `ensemble_serialized` is the
      serialized-proto output of `boosted_trees_serialize_ensemble`. It is
      created under a control dependency on `train_op`.
    """
    input_fn = _make_train_input_fn(is_classification)
    features, labels = input_fn()
    spec = boosted_trees._bt_model_fn(  # pylint:disable=protected-access
        features=features,
        labels=labels,
        mode=model_fn.ModeKeys.TRAIN,
        head=head,
        feature_columns=self._feature_columns,
        tree_hparams=self._tree_hparams,
        example_id_column_name=EXAMPLE_ID_COLUMN,
        n_batches_per_layer=1,
        config=config,
        train_in_memory=train_in_memory)
    resources.initialize_resources(resources.shared_resources()).run()
    variables.global_variables_initializer().run()
    variables.local_variables_initializer().run()

    train_op = spec.train_op
    # The single shared resource is the one serialized below as the ensemble.
    ensemble_resources = resources.shared_resources()
    self.assertEqual(1, len(ensemble_resources))
    # The control dependency makes serialization run only after train_op.
    with ops.control_dependencies([train_op]):
      _stamp, ensemble_serialized = (
          gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
              ensemble_resources[0].handle))
    return train_op, ensemble_serialized
    def _get_train_op_and_ensemble(self, head, config, is_classification,
                                   train_in_memory):
        """Calls bt_model_fn() and returns the train_op and ensemble_serialized.

        Runs `_bt_model_fn` in TRAIN mode on the training inputs. It then
        initializes shared resources and global/local variables, and checks
        that exactly one shared resource was created.

        Args:
          head: head forwarded to `_bt_model_fn`.
          config: config forwarded to `_bt_model_fn`.
          is_classification: forwarded to `_make_train_input_fn`.
          train_in_memory: forwarded to `_bt_model_fn`.

        Returns:
          A `(train_op, ensemble_serialized)` tuple. The serialized ensemble
          tensor is built under a control dependency on `train_op`.
        """
        features, labels = _make_train_input_fn(is_classification)()
        model_fn_kwargs = dict(
            features=features,
            labels=labels,
            mode=model_fn.ModeKeys.TRAIN,
            head=head,
            feature_columns=self._feature_columns,
            tree_hparams=self._tree_hparams,
            example_id_column_name=EXAMPLE_ID_COLUMN,
            n_batches_per_layer=1,
            config=config,
            train_in_memory=train_in_memory,
        )
        estimator_spec = boosted_trees._bt_model_fn(  # pylint:disable=protected-access
            **model_fn_kwargs)
        resources.initialize_resources(resources.shared_resources()).run()
        variables.global_variables_initializer().run()
        variables.local_variables_initializer().run()

        train_op = estimator_spec.train_op
        registered = resources.shared_resources()
        self.assertEqual(1, len(registered))
        # Serialize only once the training step has executed.
        with ops.control_dependencies([train_op]):
            _unused_stamp, ensemble_serialized = (
                gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
                    registered[0].handle))
        return train_op, ensemble_serialized
Beispiel #4
0
    def serialize(self):
        """Serializes the ensemble into proto and returns the serialized proto.

        Returns:
          The outputs of `boosted_trees_serialize_ensemble` for this
          resource:
            stamp_token: int64 scalar Tensor to denote the stamp of the
              resource.
            serialized_proto: string scalar Tensor of the serialized proto.
        """
        op_fn = gen_boosted_trees_ops.boosted_trees_serialize_ensemble
        return op_fn(self.resource_handle)
  def serialize(self):
    """Returns the ensemble's stamp token and serialized proto.

    Returns:
      The result of `boosted_trees_serialize_ensemble` on this resource:
        stamp_token: int64 scalar Tensor to denote the stamp of the resource.
        serialized_proto: string scalar Tensor of the serialized proto.
    """
    ensemble_handle = self.resource_handle
    return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        ensemble_handle)
  def __init__(self, resource_handle, create_op, name):
    """Creates a _TreeEnsembleSavable object.

    Serializes the ensemble and registers one save spec for the stamp token
    (suffix '_stamp') and one for the serialized proto (suffix
    '_serialized').

    Args:
      resource_handle: handle to the decision tree ensemble variable.
      create_op: the op to initialize the variable.
      name: the name to save the tree ensemble variable under.
    """
    stamp_token, serialized = (
        gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
    # Slicing a variable is not meaningful for the tree ensemble, so every
    # spec gets an empty slice_spec.
    slice_spec = ''
    tensor_suffixes = ((stamp_token, '_stamp'), (serialized, '_serialized'))
    specs = [
        saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name + suffix)
        for tensor, suffix in tensor_suffixes
    ]
    super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
    self._resource_handle = resource_handle
    self._create_op = create_op