Example #2
  def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
    """Runs prediction and returns a dictionary of the prediction results.

    Args:
      ensemble_handle: ensemble resource handle.
      ensemble_stamp: stamp of ensemble resource.
      mode: learn.ModeKeys.TRAIN or EVAL or INFER.

    Returns:
      A dictionary of prediction results -
        ENSEMBLE_STAMP, PREDICTION, PARTITION_IDS,
        NUM_LAYERS_ATTEMPTED, NUM_TREES_ATTEMPTED.
    """
    ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle,
                                                      ensemble_stamp)
    num_handlers = (
        len(self._dense_floats) + len(self._sparse_float_shapes) +
        len(self._sparse_int_shapes))
    # Used during feature selection.
    used_handlers = model_ops.tree_ensemble_used_handlers(
        ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers)

    # We don't need dropout info - we can always restore it based on the
    # seed.
    apply_dropout, seed = _dropout_params(mode, ensemble_stats)
    # Make sure ensemble stats run. This will check that the ensemble has
    # the right stamp.
    with ops.control_dependencies(ensemble_stats):
      predictions, _ = prediction_ops.gradient_trees_prediction(
          ensemble_handle,
          seed,
          self._dense_floats,
          self._sparse_float_indices,
          self._sparse_float_values,
          self._sparse_float_shapes,
          self._sparse_int_indices,
          self._sparse_int_values,
          self._sparse_int_shapes,
          learner_config=self._learner_config_serialized,
          apply_dropout=apply_dropout,
          apply_averaging=mode != learn.ModeKeys.TRAIN,
          use_locking=True,
          center_bias=self._center_bias,
          reduce_dim=self._reduce_dim)
      partition_ids = prediction_ops.gradient_trees_partition_examples(
          ensemble_handle,
          self._dense_floats,
          self._sparse_float_indices,
          self._sparse_float_values,
          self._sparse_float_shapes,
          self._sparse_int_indices,
          self._sparse_int_values,
          self._sparse_int_shapes,
          use_locking=True)

    return _make_predictions_dict(ensemble_stamp, predictions, partition_ids,
                                  ensemble_stats, used_handlers)
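The prediction ops above are created inside ops.control_dependencies(ensemble_stats) so that the stats op, which validates the ensemble stamp, is guaranteed to run first. Below is a minimal standalone sketch of that ordering pattern using plain TF 1.x ops rather than the boosted-trees kernels; the stamp constants and tensor names are purely illustrative.

import tensorflow as tf

# Illustrative stand-ins for the stamp stored on the resource and the
# stamp the caller expects to see.
stored_stamp = tf.constant(3, dtype=tf.int64)
expected_stamp = tf.constant(3, dtype=tf.int64)
raw_predictions = tf.constant([[0.1], [0.7]])

stamp_check = tf.Assert(
    tf.equal(stored_stamp, expected_stamp), ["stale ensemble stamp"])
with tf.control_dependencies([stamp_check]):
  # Created inside the block, so it cannot run before stamp_check does.
  predictions = tf.identity(raw_predictions)

with tf.Session() as sess:
  print(sess.run(predictions))  # Fails fast if the stamps disagree.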
Example #4
    def _update_ensemble():
      """A method to update the tree ensemble."""
      # Get next stamp token.
      next_ensemble_stamp = ensemble_stamp + 1
      # Finalize bias stats.
      _, _, _, bias_grads, bias_hess = bias_stats_accumulator.flush(
          ensemble_stamp, next_ensemble_stamp)

      # Finalize handler splits.
      are_splits_ready_list = []
      partition_ids_list = []
      gains_list = []
      split_info_list = []

      for handler in handlers:
        (are_splits_ready,
         partition_ids, gains, split_info) = handler.make_splits(
             ensemble_stamp, next_ensemble_stamp, class_id)
        are_splits_ready_list.append(are_splits_ready)
        partition_ids_list.append(partition_ids)
        gains_list.append(gains)
        split_info_list.append(split_info)
      # Stack all the inputs to one tensor per type.
      # This is a workaround for the slowness of graph building in tf.cond.
      # See (b/36554864).
      split_sizes = array_ops.stack([
          array_ops.shape(partition_id)[0]
          for partition_id in partition_ids_list
      ])
      partition_ids = array_ops.concat(partition_ids_list, axis=0)
      gains = array_ops.concat(gains_list, axis=0)
      split_infos = array_ops.concat(split_info_list, axis=0)

      # Determine if all splits are ready.
      are_all_splits_ready = math_ops.reduce_all(
          array_ops.stack(
              are_splits_ready_list, axis=0, name="stack_handler_readiness"))

      # Define bias centering update operation.
      def _center_bias_fn():
        # Center tree ensemble bias.
        delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess,
                                        array_ops.zeros_like(bias_grads))
        center_bias = training_ops.center_tree_ensemble_bias(
            tree_ensemble_handle=self._ensemble_handle,
            stamp_token=ensemble_stamp,
            next_stamp_token=next_ensemble_stamp,
            delta_updates=delta_updates,
            learner_config=self._learner_config_serialized)
        return continue_centering.assign(center_bias)

      # Define ensemble growing operations.
      def _grow_ensemble_ready_fn():
        # Grow the ensemble given the current candidates.
        sizes = array_ops.unstack(split_sizes)
        partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
        gains_list = list(array_ops.split(gains, sizes, axis=0))
        split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
        return training_ops.grow_tree_ensemble(
            tree_ensemble_handle=self._ensemble_handle,
            stamp_token=ensemble_stamp,
            next_stamp_token=next_ensemble_stamp,
            learning_rate=learning_rate,
            partition_ids=partition_ids_list,
            gains=gains_list,
            splits=split_info_list,
            learner_config=self._learner_config_serialized,
            dropout_seed=dropout_seed,
            center_bias=self._center_bias)

      def _grow_ensemble_not_ready_fn():
        # Don't grow the ensemble, just update the stamp.
        return training_ops.grow_tree_ensemble(
            tree_ensemble_handle=self._ensemble_handle,
            stamp_token=ensemble_stamp,
            next_stamp_token=next_ensemble_stamp,
            learning_rate=0,
            partition_ids=[],
            gains=[],
            splits=[],
            learner_config=self._learner_config_serialized,
            dropout_seed=dropout_seed,
            center_bias=self._center_bias)

      def _grow_ensemble_fn():
        # Conditionally grow an ensemble depending on whether the splits
        # from all the handlers are ready.
        return control_flow_ops.cond(are_all_splits_ready,
                                     _grow_ensemble_ready_fn,
                                     _grow_ensemble_not_ready_fn)

      # Update ensemble.
      update_ops = [are_all_splits_ready]
      update_model = control_flow_ops.cond(continue_centering, _center_bias_fn,
                                           _grow_ensemble_fn)
      update_ops.append(update_model)

      # Update ensemble stats.
      with ops.control_dependencies([update_model]):
        stats = training_ops.tree_ensemble_stats(
            self._ensemble_handle, stamp_token=next_ensemble_stamp)
        update_ops.append(self._finalized_trees.assign(stats.num_trees))
        update_ops.append(self._attempted_trees.assign(stats.attempted_trees))
        update_ops.append(num_layers.assign(stats.num_layers))
        update_ops.append(active_tree.assign(stats.active_tree))
        update_ops.append(active_layer.assign(stats.active_layer))

      # Flush step stats.
      update_ops.extend(
          steps_accumulator.flush(ensemble_stamp, next_ensemble_stamp))
      return control_flow_ops.group(*update_ops, name="update_ensemble")
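The stack/concat step above is purely a graph-building optimization: flattening the per-handler outputs into one tensor keeps the number of tensors crossing the tf.cond boundary small, and _grow_ensemble_ready_fn recovers the original list by splitting on the recorded sizes. A minimal standalone sketch of that round trip, with illustrative constants standing in for real handler outputs:

import tensorflow as tf

# Stand-ins for per-handler partition-id tensors of different lengths.
per_handler = [tf.constant([0, 1, 1]), tf.constant([2]), tf.constant([0, 3])]

# Record each length, then flatten everything into a single tensor.
split_sizes = tf.stack([tf.shape(t)[0] for t in per_handler])
flat = tf.concat(per_handler, axis=0)

# Inside the ready branch, the per-handler list is recovered unchanged.
recovered = tf.split(flat, tf.unstack(split_sizes), axis=0)

with tf.Session() as sess:
  print(sess.run(recovered))  # [array([0, 1, 1]), array([2]), array([0, 3])]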
Example #5
  def predict(self, mode):
    """Returns predictions given the features and mode.

    Args:
      mode: Mode the graph is running in (train|predict|eval).

    Returns:
      A dict of prediction tensors.

    Raises:
      ValueError: if there are no input tensors, or if they are not all placed
        on the same device.
    """
    apply_averaging = mode != learn.ModeKeys.TRAIN

    # Use the current ensemble to predict on the current batch of input.
    # For faster prediction we check if the inputs are on the same device
    # as the model. If not, we create a copy of the model on the worker.
    input_deps = (self._dense_floats + self._sparse_float_indices +
                  self._sparse_int_indices)
    if not input_deps:
      raise ValueError("No input tensors for prediction.")

    if any(i.device != input_deps[0].device for i in input_deps):
      raise ValueError("All input tensors should be on the same device.")

    # Get most current model stamp.
    ensemble_stamp = model_ops.tree_ensemble_stamp_token(self._ensemble_handle)

    # Determine if ensemble is colocated with the inputs.
    if self._ensemble_handle.device != input_deps[0].device:
      # Create a local ensemble and get its local stamp.
      with ops.name_scope("local_ensemble", "TreeEnsembleVariable") as name:
        local_ensemble_handle = (
            gen_model_ops.decision_tree_ensemble_resource_handle_op(name=name))
        create_op = gen_model_ops.create_tree_ensemble_variable(
            local_ensemble_handle, stamp_token=-1, tree_ensemble_config="")
        with ops.control_dependencies([create_op]):
          local_stamp = model_ops.tree_ensemble_stamp_token(
              local_ensemble_handle)

      # Determine whether the local ensemble is stale and update it if needed.
      def _refresh_local_ensemble_fn():
        # Serialize the model from parameter server after reading all inputs.
        with ops.control_dependencies(input_deps):
          (ensemble_stamp, serialized_model) = (
              model_ops.tree_ensemble_serialize(self._ensemble_handle))

        # Update local ensemble with the serialized model from parameter server.
        with ops.control_dependencies([create_op]):
          return model_ops.tree_ensemble_deserialize(
              local_ensemble_handle,
              stamp_token=ensemble_stamp,
              tree_ensemble_config=serialized_model), ensemble_stamp

      refresh_local_ensemble, ensemble_stamp = control_flow_ops.cond(
          math_ops.not_equal(ensemble_stamp, local_stamp),
          _refresh_local_ensemble_fn,
          lambda: (control_flow_ops.no_op(), ensemble_stamp))

      # Once updated, use the local model for prediction.
      with ops.control_dependencies([refresh_local_ensemble]):
        ensemble_stats = training_ops.tree_ensemble_stats(
            local_ensemble_handle, ensemble_stamp)
        apply_dropout, seed = _dropout_params(mode, ensemble_stats)
        # We don't need dropout info - we can always restore it based on the
        # seed.
        predictions, predictions_no_dropout, _ = (
            prediction_ops.gradient_trees_prediction(
                local_ensemble_handle,
                seed,
                self._dense_floats,
                self._sparse_float_indices,
                self._sparse_float_values,
                self._sparse_float_shapes,
                self._sparse_int_indices,
                self._sparse_int_values,
                self._sparse_int_shapes,
                learner_config=self._learner_config_serialized,
                apply_dropout=apply_dropout,
                apply_averaging=apply_averaging,
                use_locking=False,
                center_bias=self._center_bias,
                reduce_dim=self._reduce_dim))
        partition_ids = prediction_ops.gradient_trees_partition_examples(
            local_ensemble_handle,
            self._dense_floats,
            self._sparse_float_indices,
            self._sparse_float_values,
            self._sparse_float_shapes,
            self._sparse_int_indices,
            self._sparse_int_values,
            self._sparse_int_shapes,
            use_locking=False)

    else:
      with ops.device(self._ensemble_handle.device):
        ensemble_stats = training_ops.tree_ensemble_stats(
            self._ensemble_handle, ensemble_stamp)
        apply_dropout, seed = _dropout_params(mode, ensemble_stats)
        # We don't need dropout info - we can always restore it based on the
        # seed.
        predictions, predictions_no_dropout, _ = (
            prediction_ops.gradient_trees_prediction(
                self._ensemble_handle,
                seed,
                self._dense_floats,
                self._sparse_float_indices,
                self._sparse_float_values,
                self._sparse_float_shapes,
                self._sparse_int_indices,
                self._sparse_int_values,
                self._sparse_int_shapes,
                learner_config=self._learner_config_serialized,
                apply_dropout=apply_dropout,
                apply_averaging=apply_averaging,
                use_locking=False,
                center_bias=self._center_bias,
                reduce_dim=self._reduce_dim))
        partition_ids = prediction_ops.gradient_trees_partition_examples(
            self._ensemble_handle,
            self._dense_floats,
            self._sparse_float_indices,
            self._sparse_float_values,
            self._sparse_float_shapes,
            self._sparse_int_indices,
            self._sparse_int_values,
            self._sparse_int_shapes,
            use_locking=False)

    return _make_predictions_dict(ensemble_stamp, predictions,
                                  predictions_no_dropout, partition_ids,
                                  ensemble_stats)
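When the ensemble handle lives on a different device than the inputs (typically on a parameter server), predict() maintains a local replica and refreshes it only when the stamp tokens disagree. The sketch below shows the same refresh-on-stale-stamp pattern with ordinary variables; the names are illustrative, and the real code serializes and deserializes the tree ensemble resource instead of assigning scalars.

import tensorflow as tf

master_stamp = tf.placeholder(tf.int64, shape=[])
master_value = tf.placeholder(tf.float32, shape=[])

# Local replica plus the stamp it was last refreshed at.
local_stamp = tf.Variable(-1, dtype=tf.int64)
local_value = tf.Variable(0.0)

def _refresh():
  # Copy the master state, then advance the local stamp.
  with tf.control_dependencies([local_value.assign(master_value)]):
    return local_stamp.assign(master_stamp)

def _keep():
  return local_stamp.read_value()

# Only pay for the copy when the local replica is stale.
maybe_refresh = tf.cond(
    tf.not_equal(master_stamp, local_stamp), _refresh, _keep)

# Once refreshed (or confirmed fresh), read the local copy for prediction.
with tf.control_dependencies([maybe_refresh]):
  prediction_input = local_value.read_value()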