Code example #1
def _get_default_head(params, weights_name, output_type, name=None):
    """Creates a default head based on a type of a problem."""
    if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
        if params.regression:
            return head_lib.regression_head(weight_column_name=weights_name,
                                            label_dimension=params.num_outputs,
                                            enable_centered_bias=False,
                                            head_name=name)
        else:
            return head_lib.multi_class_head(params.num_classes,
                                             weight_column_name=weights_name,
                                             enable_centered_bias=False,
                                             head_name=name)
    else:
        if params.regression:
            return core_head_lib.regression_head(
                weight_column=weights_name,
                label_dimension=params.num_outputs,
                name=name,
                loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
        else:
            if params.num_classes == 2:
                return core_head_lib.binary_classification_head(
                    weight_column=weights_name,
                    name=name,
                    loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
            else:
                return core_head_lib.multi_class_head(
                    n_classes=params.num_classes,
                    weight_column=weights_name,
                    name=name,
                    loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
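
For reference, a hypothetical invocation sketch: the namedtuple below stands in for the internal params object (only the attributes _get_default_head actually reads are provided), and the call routes through the MODEL_FN_OPS branch above.

import collections

# Stand-in for the internal params object; values are illustrative.
_Params = collections.namedtuple(
    '_Params', ['regression', 'num_outputs', 'num_classes'])

params = _Params(regression=False, num_outputs=1, num_classes=2)
head = _get_default_head(params,
                         weights_name='example_weights',
                         output_type=ModelBuilderOutputType.MODEL_FN_OPS,
                         name='my_head')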
Code example #2
    def testJointLinearModel(self):
        """Tests that loss goes down with training."""
        def input_fn():
            return {
                'age':
                sparse_tensor.SparseTensor(values=['1'],
                                           indices=[[0, 0]],
                                           dense_shape=[1, 1]),
                'language':
                sparse_tensor.SparseTensor(values=['english'],
                                           indices=[[0, 0]],
                                           dense_shape=[1, 1])
            }, constant_op.constant([[1]])

        language = feature_column.sparse_column_with_hash_bucket(
            'language', 100)
        age = feature_column.sparse_column_with_hash_bucket('age', 2)

        head = head_lib.multi_class_head(n_classes=2)
        classifier = _joint_linear_estimator(head,
                                             feature_columns=[age, language])

        classifier.fit(input_fn=input_fn, steps=1000)
        loss1 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
        classifier.fit(input_fn=input_fn, steps=2000)
        loss2 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
        self.assertLess(loss2, loss1)
        self.assertLess(loss2, 0.01)
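
As a side note, a minimal sketch (assuming a TF 1.x session environment) of what the 'age' SparseTensor in input_fn above represents: a 1x1 string batch whose single entry is '1'.

import tensorflow as tf

# Same values/indices/dense_shape as the 'age' feature in input_fn above.
sp = tf.SparseTensor(values=['1'], indices=[[0, 0]], dense_shape=[1, 1])
dense = tf.sparse_tensor_to_dense(sp, default_value='')
with tf.Session() as sess:
    print(sess.run(dense))  # [[b'1']]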
Code example #3
    def __init__(self,
                 example_id_column,
                 feature_columns,
                 weight_column_name=None,
                 model_dir=None,
                 l1_regularization=0.0,
                 l2_regularization=1.0,
                 num_loss_partitions=None,
                 config=None,
                 feature_engineering_fn=None,
                 partitioner=None):
        """Construct a `SDCALogisticClassifier` object.

    Args:
      example_id_column: A string defining the feature column name representing
        example ids. Used to initialize the underlying SDCA optimizer.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the iterable should derive from `FeatureColumn`.
        Note that the order of the items is ignored at model construction time.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training. It
        will be multiplied by the loss of the example.
      model_dir: Directory to save model parameters, graph etc. This can also be
        used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      l1_regularization: L1-regularization parameter. Refers to global L1
        regularization (across all examples).
      l2_regularization: L2-regularization parameter. Refers to global L2
        regularization (across all examples).
      num_loss_partitions: Number of partitions of the global loss function
        optimized by the underlying optimizer (SDCAOptimizer).
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      partitioner: Variable partitioner for the primal weights (`div`
        partitioning strategy will be used).

    Returns:
      An `SDCALogisticClassifier` estimator.
    """
        super(SDCALogisticClassifier,
              self).__init__(example_id_column=example_id_column,
                             feature_columns=feature_columns,
                             weight_column_name=weight_column_name,
                             model_dir=model_dir,
                             head=head_lib.multi_class_head(
                                 n_classes=2,
                                 weight_column_name=weight_column_name),
                             l1_regularization=l1_regularization,
                             l2_regularization=l2_regularization,
                             num_loss_partitions=num_loss_partitions,
                             config=config,
                             feature_engineering_fn=feature_engineering_fn,
                             partitioner=partitioner)
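
A minimal usage sketch, assuming a TF 1.x build with contrib available; the column name, input_fn, and import path are assumptions for illustration.

import tensorflow as tf
from tensorflow.contrib.linear_optimizer.python.sdca_estimator import (
    SDCALogisticClassifier)

# Hypothetical single real-valued feature; SDCA also needs a unique string
# id per example, supplied through example_id_column.
price = tf.contrib.layers.real_valued_column('price')
classifier = SDCALogisticClassifier(example_id_column='example_id',
                                    feature_columns=[price])

def input_fn():
    return {
        'example_id': tf.constant(['1', '2']),
        'price': tf.constant([[0.6], [0.8]]),
    }, tf.constant([[0], [1]])

classifier.fit(input_fn=input_fn, steps=100)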
Code example #4
    def __init__(self,
                 feature_columns=None,
                 model_dir=None,
                 n_classes=2,
                 weight_column_name=None,
                 optimizer=None,
                 kernel_mappers=None,
                 config=None):
        """Construct a `KernelLinearClassifier` estimator object.

    Args:
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `FeatureColumn`.
      model_dir: Directory to save model parameters, graph etc. This can also be
        used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      n_classes: number of label classes. Default is binary classification.
        Note that class labels are integers representing the class index (i.e.
        values from 0 to n_classes-1). For arbitrary label values (e.g. string
        labels), convert to class indices first.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training. It
        will be multiplied by the loss of the example.
      optimizer: The optimizer used to train the model. If specified, it should
        be an instance of `tf.Optimizer`. If `None`, the Ftrl optimizer is used
        by default.
      kernel_mappers: Dictionary of kernel mappers to be applied to the input
        features before training a (linear) model. Keys are feature columns and
        values are lists of mappers to be applied to the corresponding feature
        column. Currently only _RealValuedColumns are supported and therefore
        all mappers should conform to the `DenseKernelMapper` interface (see
        ./mappers/dense_kernel_mapper.py).
      config: `RunConfig` object to configure the runtime settings.

    Returns:
      A `KernelLinearClassifier` estimator.

    Raises:
      ValueError: if n_classes < 2.
      ValueError: if neither feature_columns nor kernel_mappers are provided.
      ValueError: if mappers provided as kernel_mappers values are invalid.
    """
        super(KernelLinearClassifier,
              self).__init__(feature_columns=feature_columns,
                             model_dir=model_dir,
                             weight_column_name=weight_column_name,
                             head=head_lib.multi_class_head(
                                 n_classes=n_classes,
                                 weight_column_name=weight_column_name),
                             optimizer=optimizer,
                             kernel_mappers=kernel_mappers,
                             config=config)
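
A minimal usage sketch, assuming a TF 1.x build with contrib; the dimensions and names are illustrative. The mapper approximates an RBF kernel via random Fourier features, and feature_columns may be omitted when kernel_mappers covers the inputs.

import tensorflow as tf

image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
# Maps the 784-dim input into a 2000-dim space approximating an RBF kernel.
kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
    input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
classifier = tf.contrib.kernel_methods.KernelLinearClassifier(
    n_classes=10,
    optimizer=tf.train.FtrlOptimizer(learning_rate=50.0),
    kernel_mappers={image_column: [kernel_mapper]})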
Code example #5
    def testDNNModel(self):
        """Tests multi-class classification using matrix data as input."""
        cont_features = [
            feature_column.real_valued_column('feature', dimension=4)
        ]

        head = head_lib.multi_class_head(n_classes=3)
        classifier = _dnn_estimator(head,
                                    feature_columns=cont_features,
                                    hidden_units=[3, 3])

        classifier.fit(input_fn=_iris_input_fn, steps=1000)
        classifier.evaluate(input_fn=_iris_input_fn, steps=100)
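
The test relies on an _iris_input_fn helper that is not part of this excerpt. A hedged reconstruction of its apparent contract (a 4-dimensional 'feature' dict entry plus integer labels in [0, 3)) might look like:

import numpy as np
import tensorflow as tf

def _iris_like_input_fn():  # hypothetical stand-in for _iris_input_fn
    features = tf.constant(np.random.rand(150, 4), dtype=tf.float32)
    labels = tf.constant(np.random.randint(0, 3, size=(150, 1)))
    return {'feature': features}, labels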
Code example #6
    def __init__(self,
                 model_dir=None,
                 n_classes=2,
                 weight_column_name=None,
                 config=None,
                 feature_engineering_fn=None,
                 label_keys=None):
        """Initializes a DebugClassifier instance.

    Args:
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      n_classes: number of label classes. Default is binary classification.
        It must be greater than 1. Note: Class labels are integers representing
        the class index (i.e. values from 0 to n_classes-1). For arbitrary
        label values (e.g. string labels), convert to class indices first.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training. It
        will be multiplied by the loss of the example.
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2.

    Returns:
      A `DebugClassifier` estimator.

    Raises:
      ValueError: If `n_classes` < 2.
    """
        params = {
            "head":
            head_lib.multi_class_head(n_classes=n_classes,
                                      weight_column_name=weight_column_name,
                                      enable_centered_bias=True,
                                      label_keys=label_keys)
        }

        super(DebugClassifier,
              self).__init__(model_fn=debug_model_fn,
                             model_dir=model_dir,
                             config=config,
                             params=params,
                             feature_engineering_fn=feature_engineering_fn)
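
A minimal usage sketch, assuming a TF 1.x build with contrib; the import path and input_fn are assumptions. DebugClassifier ignores the features and simply learns the label distribution, which makes it a useful baseline.

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators.debug import (
    DebugClassifier)

def input_fn():
    # Features are ignored by the debug model; only labels matter.
    return ({'x': tf.constant([[1.0], [2.0], [3.0]])},
            tf.constant([[0], [1], [1]]))

classifier = DebugClassifier(n_classes=2)
classifier.fit(input_fn=input_fn, steps=10)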
Code example #7
    def __init__(self,
                 dnn_hidden_units,
                 dnn_feature_columns,
                 tree_learner_config,
                 num_trees,
                 tree_examples_per_layer,
                 n_classes=2,
                 weight_column_name=None,
                 model_dir=None,
                 config=None,
                 label_name=None,
                 label_keys=None,
                 feature_engineering_fn=None,
                 dnn_optimizer="Adagrad",
                 dnn_activation_fn=nn.relu,
                 dnn_dropout=None,
                 dnn_input_layer_partitioner=None,
                 dnn_input_layer_to_tree=True,
                 dnn_steps_to_train=10000,
                 predict_with_tree_only=False,
                 tree_feature_columns=None,
                 tree_center_bias=False,
                 dnn_to_tree_distillation_param=None,
                 use_core_versions=False,
                 override_global_step_value=None):
        """Initializes a DNNBoostedTreeCombinedClassifier instance.

    Args:
      dnn_hidden_units: List of hidden units per layer for DNN.
      dnn_feature_columns: An iterable containing all the feature columns
        used by the model's DNN.
      tree_learner_config: A config for the tree learner.
      num_trees: Number of trees to grow model to after training DNN.
      tree_examples_per_layer: Number of examples to accumulate before
        growing the tree a layer. This value has a big impact on model
        quality and should be set equal to the number of examples in
        training dataset if possible. It can also be a function that computes
        the number of examples based on the depth of the layer that's
        being built.
      n_classes: The number of label classes.
      weight_column_name: The name of weight column.
      model_dir: Directory for model exports.
      config: `RunConfig` of the estimator.
      label_name: String, name of the key in the label dict. Can be None if the
        label is a tensor (single-headed models).
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      dnn_optimizer: string, `Optimizer` object, or callable that defines the
        optimizer to use for training the DNN. If `None`, will use the Adagrad
        optimizer with default learning rate.
      dnn_activation_fn: Activation function applied to each layer of the DNN.
        If `None`, will use `tf.nn.relu`.
      dnn_dropout: When not `None`, the probability to drop out a given
        unit in the DNN.
      dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
        Defaults to `min_max_variable_partitioner` with `min_slice_size`
        64 << 20.
      dnn_input_layer_to_tree: Whether to provide the DNN's input layer
        as a feature to the tree.
      dnn_steps_to_train: Number of steps to train the DNN for before switching
        to GBDT training.
      predict_with_tree_only: Whether to use only the tree model output as the
        final prediction.
      tree_feature_columns: An iterable containing all the feature columns
        used by the model's boosted trees. If dnn_input_layer_to_tree is
        set to True, these features are in addition to dnn_feature_columns.
      tree_center_bias: Whether a separate tree should be created for
        first fitting the bias.
      dnn_to_tree_distillation_param: A tuple of (float, loss_fn), where the
        float defines the weight of the distillation loss and loss_fn computes
        the distillation loss from dnn_logits, tree_logits and a weight tensor.
        If the entire tuple is None, no distillation is applied. If only the
        loss_fn is None, sigmoid/softmax cross-entropy loss is used by default.
        When distillation is applied, `predict_with_tree_only` will be set to
        True.
      use_core_versions: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      override_global_step_value: If set, the global step will be reset to this
        value after training is done. This is particularly useful for
        hyperparameter tuning, which can't recognize early stopping due to the
        number of trees. If None, the global step is not overridden.
    """
        head = head_lib.multi_class_head(n_classes=n_classes,
                                         label_name=label_name,
                                         label_keys=label_keys,
                                         weight_column_name=weight_column_name,
                                         enable_centered_bias=False)

        def _model_fn(features, labels, mode, config):
            return _dnn_tree_combined_model_fn(
                features=features,
                labels=labels,
                mode=mode,
                head=head,
                dnn_hidden_units=dnn_hidden_units,
                dnn_feature_columns=dnn_feature_columns,
                tree_learner_config=tree_learner_config,
                num_trees=num_trees,
                tree_examples_per_layer=tree_examples_per_layer,
                config=config,
                dnn_optimizer=dnn_optimizer,
                dnn_activation_fn=dnn_activation_fn,
                dnn_dropout=dnn_dropout,
                dnn_input_layer_partitioner=dnn_input_layer_partitioner,
                dnn_input_layer_to_tree=dnn_input_layer_to_tree,
                dnn_steps_to_train=dnn_steps_to_train,
                predict_with_tree_only=predict_with_tree_only,
                tree_feature_columns=tree_feature_columns,
                tree_center_bias=tree_center_bias,
                dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
                use_core_versions=use_core_versions,
                override_global_step_value=override_global_step_value)

        super(DNNBoostedTreeCombinedClassifier,
              self).__init__(model_fn=_model_fn,
                             model_dir=model_dir,
                             config=config,
                             feature_engineering_fn=feature_engineering_fn)
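
A minimal construction sketch, assuming a TF 1.x build with contrib; the import path, feature column, and sizes are illustrative.

import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch.dnn_tree_combined_estimator import (
    DNNBoostedTreeCombinedClassifier)
from tensorflow.contrib.boosted_trees.proto import learner_pb2

learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2  # keep consistent with n_classes
learner_config.constraints.max_tree_depth = 4

feature = tf.contrib.layers.real_valued_column('feature', dimension=4)
classifier = DNNBoostedTreeCombinedClassifier(
    dnn_hidden_units=[16, 8],
    dnn_feature_columns=[feature],
    tree_learner_config=learner_config,
    num_trees=10,
    tree_examples_per_layer=100,  # ideally the training-set size
    dnn_steps_to_train=500)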
Code example #8
    def __init__(self,
                 learner_config,
                 examples_per_layer,
                 n_classes=2,
                 num_trees=None,
                 feature_columns=None,
                 weight_column_name=None,
                 model_dir=None,
                 config=None,
                 label_keys=None,
                 feature_engineering_fn=None,
                 logits_modifier_function=None,
                 center_bias=True,
                 use_core_libs=False,
                 output_leaf_index=False,
                 override_global_step_value=None,
                 num_quantiles=100):
        """Initializes a GradientBoostedDecisionTreeClassifier estimator instance.

    Args:
      learner_config: A config for the learner.
      examples_per_layer: Number of examples to accumulate before growing a
        layer. It can also be a function that computes the number of examples
        based on the depth of the layer that's being built.
      n_classes: Number of classes in the classification.
      num_trees: An int, number of trees to build.
      feature_columns: A list of feature columns.
      weight_column_name: Name of the column for weights, or None if not
        weighted.
      model_dir: Directory for model exports, etc.
      config: `RunConfig` object to configure the runtime settings.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      logits_modifier_function: A modifier function for the logits.
      center_bias: Whether a separate tree should be created for first fitting
        the bias.
      use_core_libs: Whether feature columns and loss are from the core (as
        opposed to contrib) version of tensorflow.
      output_leaf_index: Whether to output leaf indices along with predictions
        during inference. The leaf node indices are available in the
        predictions dict under the key 'leaf_index'. It is a Tensor of rank 2
        with shape [batch_size, num_trees]. For example:
          result_iter = classifier.predict(...)
          for result_dict in result_iter:
            # access the leaf index list via result_dict["leaf_index"],
            # which contains one leaf index per tree
      override_global_step_value: If set, the global step will be reset to this
        value after training is done. It should be a number greater than the
        number of steps used to train the current ensemble. For example, the
        usual way is to train a number of trees with a very large number of
        training steps; when training is done (the requested number of trees
        has been trained), this parameter can be used to set the global step
        to a large value, making it look like that many training steps ran.
        If None, the global step is not overridden.
      num_quantiles: Number of quantiles to build for numeric feature values.

    Raises:
      ValueError: If learner_config is not valid.
    """
        if n_classes > 2:
            # For multi-class classification, use our loss implementation that
            # supports second order derivative.
            def loss_fn(labels, logits, weights=None):
                result = losses.per_example_maxent_loss(labels=labels,
                                                        logits=logits,
                                                        weights=weights,
                                                        num_classes=n_classes)
                return math_ops.reduce_mean(result[0])
        else:
            loss_fn = None
        head = head_lib.multi_class_head(n_classes=n_classes,
                                         weight_column_name=weight_column_name,
                                         enable_centered_bias=False,
                                         loss_fn=loss_fn,
                                         label_keys=label_keys)
        if learner_config.num_classes == 0:
            learner_config.num_classes = n_classes
        elif learner_config.num_classes != n_classes:
            raise ValueError(
                "n_classes (%d) doesn't match learner_config (%d)." %
                (n_classes, learner_config.num_classes))
        super(GradientBoostedDecisionTreeClassifier,
              self).__init__(model_fn=model.model_builder,
                             params={
                                 'head': head,
                                 'feature_columns': feature_columns,
                                 'learner_config': learner_config,
                                 'num_trees': num_trees,
                                 'weight_column_name': weight_column_name,
                                 'examples_per_layer': examples_per_layer,
                                 'center_bias': center_bias,
                                 'logits_modifier_function':
                                 logits_modifier_function,
                                 'use_core_libs': use_core_libs,
                                 'output_leaf_index': output_leaf_index,
                                 'override_global_step_value':
                                 override_global_step_value,
                                 'num_quantiles': num_quantiles,
                             },
                             model_dir=model_dir,
                             config=config,
                             feature_engineering_fn=feature_engineering_fn)
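
A minimal usage sketch, assuming a TF 1.x build with contrib; the feature column and config values are illustrative.

import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import (
    GradientBoostedDecisionTreeClassifier)
from tensorflow.contrib.boosted_trees.proto import learner_pb2

learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2  # must match n_classes, per the check above
learner_config.constraints.max_tree_depth = 4
learner_config.learning_rate_tuner.fixed.learning_rate = 0.1

feature = tf.contrib.layers.real_valued_column('feature', dimension=4)
classifier = GradientBoostedDecisionTreeClassifier(
    learner_config=learner_config,
    examples_per_layer=100,
    num_trees=10,
    feature_columns=[feature])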
Code example #9
    def __init__(self,
                 hidden_units,
                 feature_columns,
                 model_dir=None,
                 n_classes=2,
                 weight_column_name=None,
                 optimizer=None,
                 activation_fn=nn.relu,
                 dropout=None,
                 gradient_clip_norm=None,
                 enable_centered_bias=False,
                 config=None,
                 feature_engineering_fn=None,
                 embedding_lr_multipliers=None,
                 input_layer_min_slice_size=None,
                 label_keys=None):
        """Initializes a DNNClassifier instance.

    Args:
      hidden_units: List of hidden units per layer. All layers are fully
        connected. Ex. `[64, 32]` means first layer has 64 nodes and second one
        has 32.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `FeatureColumn`.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      n_classes: number of label classes. Default is binary classification.
        It must be greater than 1. Note: Class labels are integers representing
        the class index (i.e. values from 0 to n_classes-1). For arbitrary
        label values (e.g. string labels), convert to class indices first.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training. It
        will be multiplied by the loss of the example.
      optimizer: An instance of `tf.Optimizer` used to train the model. If
        `None`, will use an Adagrad optimizer.
      activation_fn: Activation function applied to each layer. If `None`, will
        use tf.nn.relu. Note that a string containing the unqualified
        name of the op may also be provided, e.g., "relu", "tanh", or "sigmoid".
      dropout: When not `None`, the probability we will drop out a given
        coordinate.
      gradient_clip_norm: A float > 0. If provided, gradients are
        clipped to their global norm with this clipping ratio. See
        `tf.clip_by_global_norm` for more details.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
        a `float` multiplier. Multiplier will be used to multiply with learning
        rate for the embedding variables.
      input_layer_min_slice_size: Optional. The min slice size of input layer
        partitions. If not provided, will use the default of 64M.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2.

    Returns:
      A `DNNClassifier` estimator.

    Raises:
      ValueError: If `n_classes` < 2.
    """
        self._feature_columns = tuple(feature_columns or [])
        super(DNNClassifier,
              self).__init__(model_fn=_dnn_model_fn,
                             model_dir=model_dir,
                             config=config,
                             params={
                                 "head":
                                 head_lib.multi_class_head(
                                     n_classes,
                                     weight_column_name=weight_column_name,
                                     enable_centered_bias=enable_centered_bias,
                                     label_keys=label_keys),
                                 "hidden_units":
                                 hidden_units,
                                 "feature_columns":
                                 self._feature_columns,
                                 "optimizer":
                                 optimizer,
                                 "activation_fn":
                                 activation_fn,
                                 "dropout":
                                 dropout,
                                 "gradient_clip_norm":
                                 gradient_clip_norm,
                                 "embedding_lr_multipliers":
                                 embedding_lr_multipliers,
                                 "input_layer_min_slice_size":
                                 input_layer_min_slice_size,
                             },
                             feature_engineering_fn=feature_engineering_fn)
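
A minimal usage sketch, assuming a TF 1.x build with contrib; the column name and input_fn are illustrative.

import tensorflow as tf

sepal = tf.contrib.layers.real_valued_column('sepal', dimension=4)
classifier = tf.contrib.learn.DNNClassifier(hidden_units=[10, 10],
                                            feature_columns=[sepal],
                                            n_classes=3)

def input_fn():
    # Random 3-class toy batch; real code would feed actual data.
    features = {'sepal': tf.random_normal([32, 4])}
    labels = tf.random_uniform([32, 1], maxval=3, dtype=tf.int64)
    return features, labels

classifier.fit(input_fn=input_fn, steps=100)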
Code example #10
    def __init__(
            self,  # _joint_weight pylint: disable=invalid-name
            feature_columns,
            model_dir=None,
            n_classes=2,
            weight_column_name=None,
            optimizer=None,
            gradient_clip_norm=None,
            enable_centered_bias=False,
            _joint_weight=False,
            config=None,
            feature_engineering_fn=None,
            label_keys=None):
        """Construct a `LinearClassifier` estimator object.

    Args:
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `FeatureColumn`.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model.
      n_classes: number of label classes. Default is binary classification.
        Note that class labels are integers representing the class index (i.e.
        values from 0 to n_classes-1). For arbitrary label values (e.g. string
        labels), convert to class indices first.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down-weight or boost examples during training. It
        will be multiplied by the loss of the example.
      optimizer: The optimizer used to train the model. If specified, it should
        be either an instance of `tf.Optimizer` or the SDCAOptimizer. If `None`,
        the Ftrl optimizer will be used.
      gradient_clip_norm: A `float` > 0. If provided, gradients are clipped
        to their global norm with this clipping ratio. See
        `tf.clip_by_global_norm` for more details.
      enable_centered_bias: A bool. If True, estimator will learn a centered
        bias variable for each class. Rest of the model structure learns the
        residual after centered bias.
      _joint_weight: If True, the weights for all columns will be stored in a
        single (possibly partitioned) variable. It's more efficient, but it's
        incompatible with SDCAOptimizer, and requires that all feature columns
        be sparse and use the 'sum' combiner.
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.
      label_keys: Optional list of strings with size `[n_classes]` defining the
        label vocabulary. Only supported for `n_classes` > 2.

    Returns:
      A `LinearClassifier` estimator.

    Raises:
      ValueError: if n_classes < 2.
      ValueError: if enable_centered_bias=True and optimizer is SDCAOptimizer.
    """
        if (isinstance(optimizer, sdca_optimizer.SDCAOptimizer)
                and enable_centered_bias):
            raise ValueError("enable_centered_bias is not supported with SDCA")

        self._feature_columns = tuple(feature_columns or [])
        assert self._feature_columns

        chief_hook = None
        head = head_lib.multi_class_head(
            n_classes,
            weight_column_name=weight_column_name,
            enable_centered_bias=enable_centered_bias,
            label_keys=label_keys)
        params = {
            "head": head,
            "feature_columns": feature_columns,
            "optimizer": optimizer,
        }

        if isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
            assert not _joint_weight, ("_joint_weight is incompatible with the"
                                       " SDCAOptimizer")
            assert n_classes == 2, "SDCA only applies to binary classification."

            model_fn = sdca_model_fn
            # The model_fn passes the model parameters to the chief_hook. We then use
            # the hook to update weights and shrink step only on the chief.
            chief_hook = _SdcaUpdateWeightsHook()
            params.update({
                "weight_column_name": weight_column_name,
                "update_weights_hook": chief_hook,
            })
        else:
            model_fn = _linear_model_fn
            params.update({
                "gradient_clip_norm": gradient_clip_norm,
                "joint_weights": _joint_weight,
            })

        super(LinearClassifier,
              self).__init__(model_fn=model_fn,
                             model_dir=model_dir,
                             config=config,
                             params=params,
                             feature_engineering_fn=feature_engineering_fn)
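
A minimal usage sketch, assuming a TF 1.x build with contrib; the column name and input_fn are illustrative.

import tensorflow as tf

language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
classifier = tf.contrib.learn.LinearClassifier(feature_columns=[language])

def input_fn():
    return {
        'language': tf.SparseTensor(values=['english', 'french'],
                                    indices=[[0, 0], [1, 0]],
                                    dense_shape=[2, 1])
    }, tf.constant([[1], [0]])

classifier.fit(input_fn=input_fn, steps=100)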