Example #1
0
    def testFitAndEvaluateMultiClassFullDontThrowException(self):
        """Train, evaluate and predict with a FULL_HESSIAN multiclass GBDT.

        The test passes as long as none of the estimator calls raise. The
        model is deliberately tiny: 3 classes, a single tree of depth 1, no
        bias centering, with checkpoints written to a fresh temp directory.
        """
        n_classes = 3
        learner_config = learner_pb2.LearnerConfig()
        learner_config.num_classes = n_classes
        learner_config.constraints.max_tree_depth = 1
        # The multiclass strategy under test.
        learner_config.multi_class_strategy = (
            learner_pb2.LearnerConfig.FULL_HESSIAN)

        head_fn = estimator.core_multiclass_head(n_classes=n_classes)

        # NOTE(review): mkdtemp() directories are not deleted by this test.
        model_dir = tempfile.mkdtemp()
        config = run_config.RunConfig()

        classifier = estimator.CoreGradientBoostedDecisionTreeEstimator(
            learner_config=learner_config,
            head=head_fn,
            num_trees=1,
            center_bias=False,
            examples_per_layer=7,
            model_dir=model_dir,
            config=config,
            feature_columns=[core_feature_column.numeric_column("x")])

        classifier.train(input_fn=_multiclass_train_input_fn, steps=100)
        classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
        # predict() returns a lazy generator; unless it is consumed, no
        # prediction is computed and failures in the predict path go
        # unnoticed. Pull a single result instead of list(): if _eval_input_fn
        # returns plain tensors (not a Dataset), the generator may never hit
        # end-of-input -- TODO confirm against _eval_input_fn.
        next(classifier.predict(input_fn=_eval_input_fn))
  def testFitAndEvaluateMultiClassFullDontThrowException(self):
    """Train, evaluate and predict with a FULL_HESSIAN multiclass GBDT.

    The test passes as long as none of the estimator calls raise. The model is
    deliberately tiny: 3 classes, a single tree of depth 1, no bias centering,
    with checkpoints written to a fresh temp directory.
    """
    n_classes = 3
    learner_config = learner_pb2.LearnerConfig()
    learner_config.num_classes = n_classes
    learner_config.constraints.max_tree_depth = 1
    # The multiclass strategy under test.
    learner_config.multi_class_strategy = (
        learner_pb2.LearnerConfig.FULL_HESSIAN)

    head_fn = estimator.core_multiclass_head(n_classes=n_classes)

    # NOTE(review): mkdtemp() directories are not deleted by this test.
    model_dir = tempfile.mkdtemp()
    config = run_config.RunConfig()

    classifier = estimator.CoreGradientBoostedDecisionTreeEstimator(
        learner_config=learner_config,
        head=head_fn,
        num_trees=1,
        center_bias=False,
        examples_per_layer=7,
        model_dir=model_dir,
        config=config,
        feature_columns=[core_feature_column.numeric_column("x")])

    classifier.train(input_fn=_multiclass_train_input_fn, steps=100)
    classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
    # predict() returns a lazy generator; unless it is consumed, no prediction
    # is computed and failures in the predict path go unnoticed. Pull a single
    # result instead of list(): if _eval_input_fn returns plain tensors (not a
    # Dataset), the generator may never hit end-of-input -- TODO confirm
    # against _eval_input_fn.
    next(classifier.predict(input_fn=_eval_input_fn))