Code example #1
File: custom_estimators.py Project: ylzhang29/ezeeai
def _linear(units, features, feature_columns, sparse_combiner='sum'):
    linear_model = feature_column._LinearModel(  # pylint: disable=protected-access
        feature_columns=feature_columns,
        units=units,
        sparse_combiner=sparse_combiner,
        name='linear_model')
    output = linear_model(features)
    return output
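
A minimal usage sketch for the helper above, assuming TensorFlow 1.x graph mode and that `feature_column` in the snippet is `tensorflow.python.feature_column.feature_column`; the `age` column and the input values are illustrative, not taken from the original project.

import tensorflow as tf

# Illustrative feature column and input batch (assumed names and values).
feature_columns = [tf.feature_column.numeric_column('age')]
features = {'age': tf.constant([[25.0], [40.0]])}

# Build the linear projection over the feature columns and evaluate it.
logits = _linear(units=1, features=features, feature_columns=feature_columns)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(logits))  # a (2, 1) tensor: one linear logit per example
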
Code example #2
    def linear_logit_fn(features):
        """Linear model logit_fn.

        Args:
          features: This is the first item returned from the `input_fn`
                    passed to `train`, `evaluate`, and `predict`. This should be a
                    single `Tensor` or `dict` of same.

        Returns:
          A `Tensor` representing the logits.
        """
        if feature_column_v2.is_feature_column_v2(feature_columns):
            shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
            linear_model = feature_column_v2.LinearModel(
                feature_columns=feature_columns,
                units=units,
                sparse_combiner=sparse_combiner,
                shared_state_manager=shared_state_manager)
            logits = linear_model(features)
            bias = linear_model.bias_variable

            # We'd like to get all the non-bias variables associated with this
            # LinearModel. This includes the shared embedding variables as well.
            variables = linear_model.variables
            variables.remove(bias)
            variables.extend(shared_state_manager.variables)

            # Expand (potential) Partitioned variables
            bias = _get_expanded_variable_list([bias])
            variables = _get_expanded_variable_list(variables)
        else:
            linear_model = feature_column._LinearModel(  # pylint: disable=protected-access
                feature_columns=feature_columns,
                units=units,
                sparse_combiner=sparse_combiner,
                name='linear_model')
            logits = linear_model(features)
            cols_to_vars = linear_model.cols_to_vars()
            bias = cols_to_vars.pop('bias')
            variables = cols_to_vars.values()

        if units > 1:
            summary.histogram('bias', bias)
        else:
            # If units == 1, the bias value is a length-1 list of a scalar Tensor,
            # so we should provide a scalar summary.
            summary.scalar('bias', bias[0][0])
        summary.scalar('fraction_of_zero_weights',
                       _compute_fraction_of_zero(variables))
        return logits
Code example #3
File: linear.py Project: ThunderQi/tensorflow
  def linear_logit_fn(features):
    """Linear model logit_fn.

    Args:
      features: This is the first item returned from the `input_fn`
                passed to `train`, `evaluate`, and `predict`. This should be a
                single `Tensor` or `dict` of same.

    Returns:
      A `Tensor` representing the logits.
    """
    if feature_column_v2.is_feature_column_v2(feature_columns):
      shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
      linear_model = feature_column_v2.LinearModel(
          feature_columns=feature_columns,
          units=units,
          sparse_combiner=sparse_combiner,
          shared_state_manager=shared_state_manager)
      logits = linear_model(features)
      bias = linear_model.bias_variable

      # We'd like to get all the non-bias variables associated with this
      # LinearModel. This includes the shared embedding variables as well.
      variables = linear_model.variables
      variables.remove(bias)
      variables.extend(shared_state_manager.variables)

      # Expand (potential) Partitioned variables
      bias = _get_expanded_variable_list([bias])
      variables = _get_expanded_variable_list(variables)
    else:
      linear_model = feature_column._LinearModel(  # pylint: disable=protected-access
          feature_columns=feature_columns,
          units=units,
          sparse_combiner=sparse_combiner,
          name='linear_model')
      logits = linear_model(features)
      cols_to_vars = linear_model.cols_to_vars()
      bias = cols_to_vars.pop('bias')
      variables = cols_to_vars.values()

    if units > 1:
      summary.histogram('bias', bias)
    else:
      # If units == 1, the bias value is a length-1 list of a scalar Tensor,
      # so we should provide a scalar summary.
      summary.scalar('bias', bias[0][0])
    summary.scalar('fraction_of_zero_weights',
                   _compute_fraction_of_zero(variables))
    return logits
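
In both copies of this snippet, units, feature_columns and sparse_combiner are free variables, so linear_logit_fn is meant to be created by a builder that closes over them and hands the resulting logit_fn to an estimator's model function. Below is a sketch of such a wrapper; the builder name and the elided body are illustrative rather than a verbatim copy of TensorFlow's source.

def linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
  """Returns a logit_fn that closes over the linear model configuration."""

  def linear_logit_fn(features):
    # Body as in code example #3: build a v1 or v2 LinearModel over
    # feature_columns, emit the bias and sparsity summaries, return logits.
    ...

  return linear_logit_fn
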
Code example #4
def _get_keras_linear_model_predictions(self,
                                        features,
                                        feature_columns,
                                        units=1,
                                        sparse_combiner='sum',
                                        weight_collections=None,
                                        trainable=True,
                                        cols_to_vars=None):
    keras_linear_model = _LinearModel(feature_columns,
                                      units,
                                      sparse_combiner,
                                      weight_collections,
                                      trainable,
                                      name='linear_model')
    retval = keras_linear_model(features)  # pylint: disable=not-callable
    if cols_to_vars is not None:
        cols_to_vars.update(keras_linear_model.cols_to_vars())
    return retval
Code example #5
def _get_keras_linear_model_predictions(
    self,
    features,
    feature_columns,
    units=1,
    sparse_combiner='sum',
    weight_collections=None,
    trainable=True,
    cols_to_vars=None):
  keras_linear_model = _LinearModel(
      feature_columns,
      units,
      sparse_combiner,
      weight_collections,
      trainable,
      name='linear_model')
  retval = keras_linear_model(features)  # pylint: disable=not-callable
  if cols_to_vars is not None:
    cols_to_vars.update(keras_linear_model.cols_to_vars())
  return retval
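
Compared with calling _LinearModel directly, the only thing this helper adds is the optional cols_to_vars out-parameter, which it fills with the model's column-to-variable mapping. Below is a minimal sketch of that mapping, assuming TensorFlow 1.x graph mode and that _LinearModel is tensorflow.python.feature_column.feature_column._LinearModel; the price column and input values are illustrative.

import tensorflow as tf
from tensorflow.python.feature_column import feature_column as fc  # assumed TF 1.x location

price = tf.feature_column.numeric_column('price')
features = {'price': tf.constant([[1.0], [5.0]])}

linear_model = fc._LinearModel(  # pylint: disable=protected-access
    [price], units=1, sparse_combiner='sum', name='linear_model')
predictions = linear_model(features)
cols_to_vars = linear_model.cols_to_vars()  # each column, plus 'bias', mapped to its variables

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(predictions))  # (2, 1) linear predictions
  print(list(cols_to_vars))     # the 'price' column object and the 'bias' key
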