def testTrainClassifierNonInMemory(self):
    """Trains a binary classifier with train_in_memory=False for three rounds.

    Builds the train op and the serialized-ensemble tensor once. It then runs
    the train op three times. After each run it parses the serialized
    ensemble and compares it against the matching expected ensemble from
    `_get_expected_ensembles_for_classification`, in order.
    """
    ops.reset_default_graph()
    # Unpack explicitly so the test fails loudly if the helper does not
    # supply exactly three per-round expectations.
    expected_first, expected_second, expected_third = (
        self._get_expected_ensembles_for_classification())
    with self.test_session() as sess:
      # Train without train_in_memory mode.
      with sess.graph.as_default():
        train_op, ensemble_serialized = self._get_train_op_and_ensemble(
            boosted_trees._create_classification_head(n_classes=2),
            run_config.RunConfig(),
            is_classification=True,
            train_in_memory=False)
      # One training step per round, validating the trained ensemble after
      # each step against the expectation for that round.
      for expected in (expected_first, expected_second, expected_third):
        _, serialized = sess.run([train_op, ensemble_serialized])
        ensemble_proto = boosted_trees_pb2.TreeEnsemble()
        ensemble_proto.ParseFromString(serialized)
        self.assertProtoEquals(expected, ensemble_proto)