def testFitAndEvaluateMultiClassFullDontThrowException(self):
        """Multi-class (FULL_HESSIAN) classifier runs fit/evaluate/export/predict.

        Besides the per-prediction key check, the test passes as long as none
        of the estimator calls raises.
        """
        learner_config = learner_pb2.LearnerConfig()
        learner_config.num_classes = 3
        learner_config.constraints.max_tree_depth = 1
        learner_config.multi_class_strategy = (
            learner_pb2.LearnerConfig.FULL_HESSIAN)

        model_dir = tempfile.mkdtemp()
        config = run_config.RunConfig()

        classifier = estimator.GradientBoostedDecisionTreeClassifier(
            learner_config=learner_config,
            n_classes=learner_config.num_classes,
            num_trees=1,
            examples_per_layer=7,
            model_dir=model_dir,
            config=config,
            center_bias=False,
            feature_columns=[contrib_feature_column.real_valued_column("x")])

        classifier.fit(input_fn=_multiclass_train_input_fn, steps=100)
        classifier.evaluate(input_fn=_eval_input_fn, steps=1)
        classifier.export(self._export_dir_base)

        # Materialize the predictions so an empty result fails loudly instead
        # of letting the per-prediction check below pass vacuously.
        predictions = list(classifier.predict(input_fn=_eval_input_fn))
        self.assertGreater(len(predictions), 0)
        for prediction_dict in predictions:
            # assertIn reports the missing key and the dict contents on failure.
            self.assertIn("classes", prediction_dict)
    def testFitAndEvaluateDontThrowException(self):
        """Binary classifier: fit, evaluate and export complete without error.

        There are no explicit assertions; the test passes as long as none of
        the estimator calls below raises.
        """
        tree_learner = learner_pb2.LearnerConfig()
        tree_learner.num_classes = 2
        tree_learner.constraints.max_tree_depth = 1

        # Argument order keeps mkdtemp() -> RunConfig() -> feature column
        # construction in the same sequence as before.
        gbdt = estimator.GradientBoostedDecisionTreeClassifier(
            learner_config=tree_learner,
            num_trees=1,
            examples_per_layer=3,
            model_dir=tempfile.mkdtemp(),
            config=run_config.RunConfig(),
            feature_columns=[contrib_feature_column.real_valued_column("x")])

        gbdt.fit(input_fn=_train_input_fn, steps=15)
        gbdt.evaluate(input_fn=_eval_input_fn, steps=1)
        gbdt.export(self._export_dir_base)
    def testOverridesGlobalSteps(self):
        """Checkpoint after fit() records override_global_step_value as its step.

        fit() is asked for only 15 steps, yet the checkpoint is expected to
        carry the overridden global step value.
        """
        overridden_step = 10000000

        tree_learner = learner_pb2.LearnerConfig()
        tree_learner.num_classes = 2
        tree_learner.constraints.max_tree_depth = 2

        gbdt = estimator.GradientBoostedDecisionTreeClassifier(
            learner_config=tree_learner,
            num_trees=1,
            examples_per_layer=3,
            model_dir=tempfile.mkdtemp(),
            config=run_config.RunConfig(),
            feature_columns=[contrib_feature_column.real_valued_column("x")],
            output_leaf_index=False,
            override_global_step_value=overridden_step)

        gbdt.fit(input_fn=_train_input_fn, steps=15)
        self._assert_checkpoint(gbdt.model_dir, global_step=overridden_step)
    def testThatLeafIndexIsInPredictions(self):
        """With output_leaf_index=True, predictions expose leaf indices and logits."""
        learner_config = learner_pb2.LearnerConfig()
        learner_config.num_classes = 2
        learner_config.constraints.max_tree_depth = 1
        model_dir = tempfile.mkdtemp()
        config = run_config.RunConfig()

        classifier = estimator.GradientBoostedDecisionTreeClassifier(
            learner_config=learner_config,
            num_trees=1,
            examples_per_layer=3,
            model_dir=model_dir,
            config=config,
            feature_columns=[contrib_feature_column.real_valued_column("x")],
            output_leaf_index=True)

        classifier.fit(input_fn=_train_input_fn, steps=15)

        # Materialize the predictions so an empty result fails loudly instead
        # of letting the per-prediction checks below pass vacuously.
        predictions = list(classifier.predict(input_fn=_eval_input_fn))
        self.assertGreater(len(predictions), 0)
        for prediction_dict in predictions:
            # assertIn reports the missing key and the dict contents on failure.
            self.assertIn("leaf_index", prediction_dict)
            self.assertIn("logits", prediction_dict)
    def testForcedInitialSplits(self):
        """Every tree starts from the forced `each_tree_start` subtree.

        The learner is seeded with a fixed two-layer subtree (a root split plus
        two child splits ending in zero-valued leaves). The test trains a
        classifier and compares the checkpointed ensemble against a golden
        proto. In that proto, each tree keeps the forced splits and grows new
        structure below them.
        """
        learner_config = learner_pb2.LearnerConfig()
        learner_config.num_classes = 2
        learner_config.constraints.max_tree_depth = 3

        initial_subtree = """
            nodes {
              dense_float_binary_split {
                feature_column: 0
                threshold: -0.5
                left_id: 1
                right_id: 2
              }
              node_metadata {
                gain: 0
              }
            }
            nodes {
              dense_float_binary_split {
                feature_column: 1
                threshold: 0.52
                left_id: 3
                right_id: 4
              }
              node_metadata {
                gain: 0
              }
            }
            nodes {
              dense_float_binary_split {
                feature_column: 1
                threshold: 0.554
                left_id: 5
                right_id: 6
              }
              node_metadata {
                gain: 0
              }
            }
            nodes {
              leaf {
                vector {
                  value: 0.0
                }
              }
            }
            nodes {
              leaf {
                vector {
                  value: 0.0
                }
              }
            }
            nodes {
              leaf {
                vector {
                  value: 0.0
                }
              }
            }
            nodes {
              leaf {
                vector {
                  value: 0.0
                }
              }
            }
    """
        tree_proto = tree_config_pb2.DecisionTreeConfig()
        text_format.Merge(initial_subtree, tree_proto)

        # Set initial subtree info. The forced subtree occupies the first two
        # layers (the root, then nodes 1 and 2), hence num_layers = 2.
        learner_config.each_tree_start.CopyFrom(tree_proto)
        learner_config.each_tree_start_num_layers = 2

        model_dir = tempfile.mkdtemp()
        config = run_config.RunConfig()

        classifier = estimator.GradientBoostedDecisionTreeClassifier(
            learner_config=learner_config,
            num_trees=2,
            examples_per_layer=6,
            model_dir=model_dir,
            config=config,
            center_bias=False,
            feature_columns=[contrib_feature_column.real_valued_column("x")],
            output_leaf_index=False)

        classifier.fit(input_fn=_train_input_fn, steps=100)
        # Without a global-step override, the checkpoint records only 6 steps,
        # far fewer than the 100 requested, so training stops early.
        ensemble = self._assert_checkpoint_and_return_model(
            classifier.model_dir, global_step=6)

        # Expected ensemble: two finalized trees, each of which grew one extra
        # split (node 5) below the forced subtree. There is also a third,
        # non-finalized tree that is still just the forced subtree with
        # all-zero leaves.
        # TODO(nponomareva): find a better way to test this.
        expected_ensemble = """
      trees {
        nodes {
          dense_float_binary_split {
            threshold: -0.5
            left_id: 1
            right_id: 2
          }
          node_metadata {
          }
        }
        nodes {
          dense_float_binary_split {
            feature_column: 1
            threshold: 0.52
            left_id: 3
            right_id: 4
          }
          node_metadata {
          }
        }
        nodes {
          dense_float_binary_split {
            feature_column: 1
            threshold: 0.554
            left_id: 5
            right_id: 6
          }
          node_metadata {
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          dense_float_binary_split {
            threshold: 1.0
            left_id: 7
            right_id: 8
          }
          node_metadata {
            gain: 0.888888895512
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: -2.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 2.00000023842
            }
          }
        }
      }
      trees {
        nodes {
          dense_float_binary_split {
            threshold: -0.5
            left_id: 1
            right_id: 2
          }
          node_metadata {
          }
        }
        nodes {
          dense_float_binary_split {
            feature_column: 1
            threshold: 0.52
            left_id: 3
            right_id: 4
          }
          node_metadata {
          }
        }
        nodes {
          dense_float_binary_split {
            feature_column: 1
            threshold: 0.554
            left_id: 5
            right_id: 6
          }
          node_metadata {
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          dense_float_binary_split {
            threshold: 1.0
            left_id: 7
            right_id: 8
          }
          node_metadata {
            gain: 0.727760672569
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: -1.81873059273
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 1.81873047352
            }
          }
        }
      }
      trees {
        nodes {
          dense_float_binary_split {
            threshold: -0.5
            left_id: 1
            right_id: 2
          }
          node_metadata {
          }
        }
        nodes {
          dense_float_binary_split {
            feature_column: 1
            threshold: 0.52
            left_id: 3
            right_id: 4
          }
          node_metadata {
          }
        }
        nodes {
          dense_float_binary_split {
            feature_column: 1
            threshold: 0.554
            left_id: 5
            right_id: 6
          }
          node_metadata {
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 0.0
            }
          }
        }
      }
      tree_weights: 0.10000000149
      tree_weights: 0.10000000149
      tree_weights: 0.10000000149
      tree_metadata {
        num_tree_weight_updates: 1
        num_layers_grown: 3
        is_finalized: true
      }
      tree_metadata {
        num_tree_weight_updates: 1
        num_layers_grown: 3
        is_finalized: true
      }
      tree_metadata {
        num_tree_weight_updates: 1
        num_layers_grown: 2
      }
      growing_metadata {
        num_layers_attempted: 3
      }
    """
        self.assertProtoEquals(expected_ensemble, ensemble)