Esempio n. 1
0
    def testCreate(self):
        """Checks that a freshly created ensemble variable predicts and stamps.

        The ensemble holds one tree (weight 1.0) with a single node populated
        via `_append_to_leaf(..., 0, -0.4)`, and is registered with stamp
        token 3. Predicting on the fixture features must yield
        [[-0.4], [-0.4]], and the stored stamp token must read back as 3.
        """
        with self.test_session():
            ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
            single_tree = ensemble_proto.trees.add()
            _append_to_leaf(single_tree.nodes.add().leaf, 0, -0.4)
            ensemble_proto.tree_weights.append(1.0)

            # Two-class learner configuration.
            learner_proto = learner_pb2.LearnerConfig()
            learner_proto.num_classes = 2

            handle = model_ops.tree_ensemble_variable(
                stamp_token=3,
                tree_ensemble_config=ensemble_proto.SerializeToString(),
                name="create_tree")
            resources.initialize_resources(resources.shared_resources()).run()

            # Group the sparse-float fixture tensors by role before predicting.
            float_indices = [self._sparse_float_indices1,
                             self._sparse_float_indices2]
            float_values = [self._sparse_float_values1,
                            self._sparse_float_values2]
            float_shapes = [self._sparse_float_shape1,
                            self._sparse_float_shape2]
            predictions, _, _ = prediction_ops.gradient_trees_prediction(
                handle,
                self._seed,
                [self._dense_float_tensor],
                float_indices,
                float_values,
                float_shapes,
                [self._sparse_int_indices1],
                [self._sparse_int_values1],
                [self._sparse_int_shape1],
                learner_config=learner_proto.SerializeToString(),
                apply_dropout=False,
                apply_averaging=False,
                center_bias=False,
                reduce_dim=True)
            self.assertAllClose(predictions.eval(), [[-0.4], [-0.4]])

            stored_stamp = model_ops.tree_ensemble_stamp_token(handle)
            self.assertEqual(stored_stamp.eval(), 3)
Esempio n. 2
0
  def testCreate(self):
    """Checks that a freshly created ensemble variable predicts and stamps.

    The ensemble holds one tree (weight 1.0) with a single node populated via
    `_append_to_leaf(..., 0, -0.4)`, and is registered with stamp token 3.
    Predicting on the fixture features must yield [[-0.4], [-0.4]], and the
    stored stamp token must read back as 3.
    """
    with self.cached_session():
      ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
      single_tree = ensemble_proto.trees.add()
      _append_to_leaf(single_tree.nodes.add().leaf, 0, -0.4)
      ensemble_proto.tree_weights.append(1.0)

      # Two-class learner configuration.
      learner_proto = learner_pb2.LearnerConfig()
      learner_proto.num_classes = 2

      handle = model_ops.tree_ensemble_variable(
          stamp_token=3,
          tree_ensemble_config=ensemble_proto.SerializeToString(),
          name="create_tree")
      resources.initialize_resources(resources.shared_resources()).run()

      # Group the sparse-float fixture tensors by role before predicting.
      float_indices = [self._sparse_float_indices1,
                       self._sparse_float_indices2]
      float_values = [self._sparse_float_values1, self._sparse_float_values2]
      float_shapes = [self._sparse_float_shape1, self._sparse_float_shape2]
      predictions, _ = prediction_ops.gradient_trees_prediction(
          handle,
          self._seed,
          [self._dense_float_tensor],
          float_indices,
          float_values,
          float_shapes,
          [self._sparse_int_indices1],
          [self._sparse_int_values1],
          [self._sparse_int_shape1],
          learner_config=learner_proto.SerializeToString(),
          apply_dropout=False,
          apply_averaging=False,
          center_bias=False,
          reduce_dim=True)
      self.assertAllClose(predictions.eval(), [[-0.4], [-0.4]])

      stored_stamp = model_ops.tree_ensemble_stamp_token(handle)
      self.assertEqual(stored_stamp.eval(), 3)
Esempio n. 3
0
    def predict(self, mode):
        """Returns predictions given the features and mode.

        If the shared ensemble (`self._ensemble_handle`) is on a different
        device from the inputs, a local ensemble is created next to the
        inputs (stamp -1, empty config) and is re-populated from the shared
        one whenever the two stamp tokens differ; predictions then run
        against the local copy. Otherwise predictions run directly against
        the shared ensemble, placed on its device.

        Args:
          mode: Mode the graph is running in (train|predict|eval).

        Returns:
          A dict of predictions tensors.

        Raises:
          ValueError: if there are no input tensors, or if the input tensors
            are not all on the same device.
        """
        inputs = (self._dense_floats + self._sparse_float_indices +
                  self._sparse_int_indices)
        if not inputs:
            raise ValueError("No input tensors for prediction.")
        input_device = inputs[0].device
        if any(tensor.device != input_device for tensor in inputs):
            raise ValueError("All input tensors should be on the same device.")

        # Latest stamp token of the shared ensemble.
        shared_stamp = model_ops.tree_ensemble_stamp_token(
            self._ensemble_handle)

        if self._ensemble_handle.device == input_device:
            # Colocated with the inputs: use the shared ensemble directly.
            with ops.device(self._ensemble_handle.device):
                return self._predict_and_return_dict(self._ensemble_handle,
                                                     shared_stamp, mode)

        # Not colocated: keep a local copy of the ensemble on the worker.
        with ops.name_scope("local_ensemble",
                            "TreeEnsembleVariable") as scope_name:
            local_handle = (
                gen_model_ops.decision_tree_ensemble_resource_handle_op(
                    name=scope_name))
            create_local = gen_model_ops.create_tree_ensemble_variable(
                local_handle, stamp_token=-1, tree_ensemble_config="")
            with ops.control_dependencies([create_local]):
                local_stamp = model_ops.tree_ensemble_stamp_token(local_handle)

        def _pull_shared_into_local():
            """Copies the shared ensemble into the local one (stale branch)."""
            # Serialize only once every input tensor has been read.
            with ops.control_dependencies(inputs):
                latest_stamp, serialized = model_ops.tree_ensemble_serialize(
                    self._ensemble_handle)
            # The local resource must exist before it is overwritten.
            with ops.control_dependencies([create_local]):
                deserialize_op = model_ops.tree_ensemble_deserialize(
                    local_handle,
                    stamp_token=latest_stamp,
                    tree_ensemble_config=serialized)
            return deserialize_op, latest_stamp

        def _local_is_current():
            """Leaves the local ensemble untouched (up-to-date branch)."""
            return control_flow_ops.no_op(), shared_stamp

        is_stale = math_ops.not_equal(shared_stamp, local_stamp)
        sync_op, served_stamp = control_flow_ops.cond(
            is_stale, _pull_shared_into_local, _local_is_current)

        # Predict from the local copy only after any refresh has completed.
        with ops.control_dependencies([sync_op]):
            return self._predict_and_return_dict(local_handle, served_stamp,
                                                 mode)
Esempio n. 4
0
  def predict(self, mode):
    """Returns predictions given the features and mode.

    If the shared ensemble (`self._ensemble_handle`) is on a different device
    from the inputs, a local ensemble is created next to the inputs (stamp -1,
    empty config) and is re-populated from the shared one whenever the two
    stamp tokens differ; all prediction ops then run against the local copy.
    Otherwise they run directly against the shared ensemble, placed on its
    device.

    Args:
      mode: Mode the graph is running in (train|predict|eval). Averaging is
        applied in every mode except `learn.ModeKeys.TRAIN`.

    Returns:
      A dict of predictions tensors, built by `_make_predictions_dict` from
      the ensemble stamp, the predictions with and without dropout, the
      partition ids and the ensemble stats.

    Raises:
      ValueError: if there are no input tensors, or if the input tensors are
        not all on the same device.
    """
    # Use the current ensemble to predict on the current batch of input.
    # For faster prediction we check if the inputs are on the same device
    # as the model. If not, we create a copy of the model on the worker.
    input_deps = (self._dense_floats + self._sparse_float_indices +
                  self._sparse_int_indices)
    if not input_deps:
      raise ValueError("No input tensors for prediction.")

    if any(i.device != input_deps[0].device for i in input_deps):
      raise ValueError("All input tensors should be on the same device.")

    # Pure Python flag; computed after validation so bad inputs fail first.
    apply_averaging = mode != learn.ModeKeys.TRAIN

    def _build_prediction_ops(ensemble_handle, stamp):
      """Builds stats, prediction and partition ops against one ensemble.

      Called from inside the caller's device / control-dependency context so
      every op created here inherits that context.

      Returns:
        (ensemble_stats, predictions, predictions_no_dropout, partition_ids)
      """
      ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, stamp)
      apply_dropout, seed = _dropout_params(mode, ensemble_stats)
      # We don't need dropout info - we can always restore it based on the
      # seed.
      predictions, predictions_no_dropout, _ = (
          prediction_ops.gradient_trees_prediction(
              ensemble_handle,
              seed,
              self._dense_floats,
              self._sparse_float_indices,
              self._sparse_float_values,
              self._sparse_float_shapes,
              self._sparse_int_indices,
              self._sparse_int_values,
              self._sparse_int_shapes,
              learner_config=self._learner_config_serialized,
              apply_dropout=apply_dropout,
              apply_averaging=apply_averaging,
              use_locking=False,
              center_bias=self._center_bias,
              reduce_dim=self._reduce_dim))
      partition_ids = prediction_ops.gradient_trees_partition_examples(
          ensemble_handle,
          self._dense_floats,
          self._sparse_float_indices,
          self._sparse_float_values,
          self._sparse_float_shapes,
          self._sparse_int_indices,
          self._sparse_int_values,
          self._sparse_int_shapes,
          use_locking=False)
      return ensemble_stats, predictions, predictions_no_dropout, partition_ids

    # Get most current model stamp.
    ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)

    # Determine if ensemble is colocated with the inputs.
    if self._ensemble_handle.device != input_deps[0].device:
      # Create a local ensemble and get its local stamp.
      with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name:
        local_ensemble_handle = (
            gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name))
        create_op = gen_model_ops.create_tree_ensemble_variable(
            local_ensemble_handle, stamp_token=-1, tree_ensemble_config="")
        with ops.control_dependencies([create_op]):
          local_stamp = model_ops.tree_ensemble_stamp_token(
              local_ensemble_handle)

      # Determine whether the local ensemble is stale and update it if needed.
      def _refresh_local_ensemble_fn():
        # Serialize the model from parameter server after reading all inputs.
        with ops.control_dependencies(input_deps):
          (ensemble_stamp, serialized_model) = (
              model_ops.tree_ensemble_serialize(self._ensemble_handle))

        # Update local ensemble with the serialized model from parameter server.
        with ops.control_dependencies([create_op]):
          return model_ops.tree_ensemble_deserialize(
              local_ensemble_handle,
              stamp_token=ensemble_stamp,
              tree_ensemble_config=serialized_model), ensemble_stamp

      refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond(
          math_ops.not_equal(ensemble_stamp,
                             local_stamp), _refresh_local_ensemble_fn,
          lambda: (control_flow_ops.no_op(), ensemble_stamp))

      # Once updated, use the local model for prediction.
      with ops.control_dependencies([refresh_local_ensemble]):
        (ensemble_stats, predictions, predictions_no_dropout,
         partition_ids) = _build_prediction_ops(local_ensemble_handle,
                                                ensemble_stamp)
    else:
      # Use ensemble_handle directly, if colocated.
      with ops.device(self._ensemble_handle.device):
        (ensemble_stats, predictions, predictions_no_dropout,
         partition_ids) = _build_prediction_ops(self._ensemble_handle,
                                                ensemble_stamp)

    return _make_predictions_dict(ensemble_stamp, predictions,
                                  predictions_no_dropout, partition_ids,
                                  ensemble_stats)
Esempio n. 5
0
  def predict(self, mode):
    """Returns predictions given the features and mode.

    If the shared ensemble (`self._ensemble_handle`) is on a different device
    from the inputs, a local ensemble is created next to the inputs (stamp -1,
    empty config) and is re-populated from the shared one whenever the two
    stamp tokens differ; predictions then run against the local copy.
    Otherwise predictions run directly against the shared ensemble, placed on
    its device.

    Args:
      mode: Mode the graph is running in (train|predict|eval).

    Returns:
      A dict of predictions tensors.

    Raises:
      ValueError: if there are no input tensors, or if the input tensors are
        not all on the same device.
    """
    inputs = (self._dense_floats + self._sparse_float_indices +
              self._sparse_int_indices)
    if not inputs:
      raise ValueError("No input tensors for prediction.")
    input_device = inputs[0].device
    if any(tensor.device != input_device for tensor in inputs):
      raise ValueError("All input tensors should be on the same device.")

    # Latest stamp token of the shared ensemble.
    shared_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)

    if self._ensemble_handle.device == input_device:
      # Colocated with the inputs: use the shared ensemble directly.
      with ops.device(self._ensemble_handle.device):
        return self._predict_and_return_dict(self._ensemble_handle,
                                             shared_stamp, mode)

    # Not colocated: keep a local copy of the ensemble on the worker.
    with ops.name_scope("local_ensemble",
                        "TreeEnsembleVariable") as scope_name:
      local_handle = (
          gen_model_ops.decision_tree_ensemble_resource_handle_op(
              name=scope_name))
      create_local = gen_model_ops.create_tree_ensemble_variable(
          local_handle, stamp_token=-1, tree_ensemble_config="")
      with ops.control_dependencies([create_local]):
        local_stamp = model_ops.tree_ensemble_stamp_token(local_handle)

    def _pull_shared_into_local():
      """Copies the shared ensemble into the local one (stale branch)."""
      # Serialize only once every input tensor has been read.
      with ops.control_dependencies(inputs):
        latest_stamp, serialized = model_ops.tree_ensemble_serialize(
            self._ensemble_handle)
      # The local resource must exist before it is overwritten.
      with ops.control_dependencies([create_local]):
        deserialize_op = model_ops.tree_ensemble_deserialize(
            local_handle,
            stamp_token=latest_stamp,
            tree_ensemble_config=serialized)
      return deserialize_op, latest_stamp

    def _local_is_current():
      """Leaves the local ensemble untouched (up-to-date branch)."""
      return control_flow_ops.no_op(), shared_stamp

    is_stale = math_ops.not_equal(shared_stamp, local_stamp)
    sync_op, served_stamp = control_flow_ops.cond(
        is_stale, _pull_shared_into_local, _local_is_current)

    # Predict from the local copy only after any refresh has completed.
    with ops.control_dependencies([sync_op]):
      return self._predict_and_return_dict(local_handle, served_stamp, mode)