class AdversarialReweightingModelTest(tf.test.TestCase, absltest.TestCase):
  """Smoke tests for `adversarial_reweighting_model.get_estimator`.

  Each test uses the toy CSV files under `data/toy_data` next to this module
  and a fresh temporary model directory that is deleted when the test ends.
  """

  def setUp(self):
    """Creates the temp model dir, hyperparameters and the dataset loader."""
    super(AdversarialReweightingModelTest, self).setUp()
    # Imported locally so the class does not rely on a module-level import.
    import shutil  # pylint: disable=g-import-not-at-top

    self.model_dir = tempfile.mkdtemp()
    # Checkpoints written during a test must not outlive it.
    self.addCleanup(shutil.rmtree, self.model_dir, ignore_errors=True)

    self.primary_hidden_units = [16, 4]
    self.batch_size = 8
    self.train_steps = 20
    self.test_steps = 5
    self.pretrain_steps = 5

    test_dir = os.path.dirname(__file__)
    self.dataset_base_dir = os.path.join(test_dir, 'data/toy_data')
    self.train_file = [os.path.join(test_dir, 'data/toy_data/train.csv')]
    self.test_file = [os.path.join(test_dir, 'data/toy_data/test.csv')]
    self.load_dataset = UCIAdultInput(
        dataset_base_dir=self.dataset_base_dir,
        train_file=self.train_file,
        test_file=self.test_file)
    self.target_column_name = 'income'

  def _get_train_test_input_fn(self):
    """Returns the (train, eval) input functions of the toy dataset."""
    train_input_fn = self.load_dataset.get_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN, batch_size=self.batch_size)
    test_input_fn = self.load_dataset.get_input_fn(
        mode=tf.estimator.ModeKeys.EVAL, batch_size=self.batch_size)
    return train_input_fn, test_input_fn

  def _build_estimator(self, include_sensitive_columns=True,
                       **estimator_kwargs):
    """Builds an estimator from the fixture settings.

    Args:
      include_sensitive_columns: Forwarded to
        `UCIAdultInput.get_feature_columns`.
      **estimator_kwargs: Extra keyword arguments forwarded unchanged to
        `adversarial_reweighting_model.get_estimator`.

    Returns:
      Whatever `adversarial_reweighting_model.get_estimator` returns.
    """
    config = tf.estimator.RunConfig(model_dir=self.model_dir,
                                    save_checkpoints_steps=2)
    feature_columns, _, _, label_column_name = (
        self.load_dataset.get_feature_columns(
            include_sensitive_columns=include_sensitive_columns))
    return adversarial_reweighting_model.get_estimator(
        feature_columns=feature_columns,
        label_column_name=label_column_name,
        config=config,
        model_dir=self.model_dir,
        primary_hidden_units=self.primary_hidden_units,
        batch_size=self.batch_size,
        pretrain_steps=self.pretrain_steps,
        **estimator_kwargs)

  def _explicit_hyperparameters(self, **overrides):
    """Returns the explicit hyperparameter set shared by the creation tests.

    Args:
      **overrides: Entries that replace the shared values.

    Returns:
      A dict of keyword arguments for `_build_estimator`.
    """
    hyperparameters = dict(
        primary_learning_rate=0.01,
        adversary_learning_rate=0.01,
        optimizer='Adagrad',
        activation=tf.nn.relu,
        adversary_loss_type='ce_loss',
        adversary_include_label=True,
        upweight_positive_instance_only=False)
    hyperparameters.update(overrides)
    return hyperparameters

  def _train_and_evaluate(self, estimator):
    """Trains for `train_steps`, then returns the evaluation results."""
    train_input_fn, test_input_fn = self._get_train_test_input_fn()
    estimator.train(input_fn=train_input_fn, steps=self.train_steps)
    return estimator.evaluate(input_fn=test_input_fn, steps=self.test_steps)

  def test_get_feature_columns_with_demographics(self):
    """Expects 14 feature columns when sensitive columns are included."""
    feature_columns, _, _, target_variable_column = (
        self.load_dataset.get_feature_columns(include_sensitive_columns=True))
    self.assertLen(feature_columns, 14)
    self.assertEqual(target_variable_column, self.target_column_name)

  def test_get_feature_columns_without_demographics(self):
    """Expects 12 feature columns when sensitive columns are excluded."""
    feature_columns, _, _, target_variable_column = (
        self.load_dataset.get_feature_columns(include_sensitive_columns=False))
    self.assertLen(feature_columns, 12)
    self.assertEqual(target_variable_column, self.target_column_name)

  def test_get_input_fn(self):
    """Checks the keys of the targets and the number of features."""
    input_fn = self.load_dataset.get_input_fn(
        mode=tf.estimator.ModeKeys.TRAIN, batch_size=self.batch_size)
    features, targets = input_fn()
    for expected_key in ('sex', 'race', 'subgroup', self.target_column_name):
      self.assertIn(expected_key, targets)
    self.assertLen(features, 15)

  def test_eval_results_adversarial_reweighting_model(self):
    """Checks that evaluation reports auc and the confusion-matrix counts."""
    estimator = self._build_estimator()
    self.assertIsInstance(estimator, tf.estimator.Estimator)
    eval_results = self._train_and_evaluate(estimator)
    self.assertNotEmpty(eval_results)
    # Checks if auc and all tp, tn, fp, fn keys are present in eval_results.
    for metric_key in ('auc', 'fp', 'fn', 'tp', 'tn'):
      self.assertIn(metric_key, eval_results)

  def test_global_steps_adversarial_reweighting_model(self):
    """Checks the global step reached after training."""
    estimator = self._build_estimator()
    self.assertIsInstance(estimator, tf.estimator.Estimator)
    eval_results = self._train_and_evaluate(estimator)
    # Checks if global step has reached the specified number of train_steps.
    # As an artifact of the way train_ops is defined in
    # _AdversarialReweightingEstimator, training stops two steps after the
    # specified number of train_steps.
    self.assertIn('global_step', eval_results)
    self.assertEqual(eval_results['global_step'], self.train_steps + 2)

  def test_create_adversarial_reweighting_estimator_with_demographics(self):
    """Creates an estimator with sensitive columns and explicit settings."""
    estimator = self._build_estimator(**self._explicit_hyperparameters())
    self.assertIsInstance(estimator, tf.estimator.Estimator)

  def test_create_adversarial_reweighting_estimator_with_hinge_loss(self):
    """Creates an estimator whose adversary uses 'hinge_loss'."""
    estimator = self._build_estimator(
        **self._explicit_hyperparameters(adversary_loss_type='hinge_loss'))
    self.assertIsInstance(estimator, tf.estimator.Estimator)

  def test_create_adversarial_reweighting_estimator_with_crossentropy_loss(
      self):
    """Creates an estimator whose adversary uses 'ce_loss'.

    The arguments are the same as in the `with_demographics` test; the test
    is kept so the loss type stays covered under its own name.
    """
    estimator = self._build_estimator(
        **self._explicit_hyperparameters(adversary_loss_type='ce_loss'))
    self.assertIsInstance(estimator, tf.estimator.Estimator)

  def test_create_adversarial_reweighting_estimator_without_label(self):
    """Creates an estimator with `adversary_include_label=False`."""
    estimator = self._build_estimator(
        **self._explicit_hyperparameters(adversary_include_label=False))
    self.assertIsInstance(estimator, tf.estimator.Estimator)

  def test_create_adversarial_reweighting_estimator_without_demographics(self):
    """Creates an estimator from feature columns without sensitive columns."""
    estimator = self._build_estimator(
        include_sensitive_columns=False, **self._explicit_hyperparameters())
    self.assertIsInstance(estimator, tf.estimator.Estimator)
# Example no. 2
# 0
class FairnessMetricsTest(tf.test.TestCase, absltest.TestCase):
    """Tests the metrics that `RobustFairnessMetrics` adds to an estimator.

    Each test builds an adversarial reweighting estimator, attaches the
    fairness metrics with `tf.estimator.add_metrics`, trains and evaluates on
    the toy CSV files under `data/toy_data`, and checks the keys of the
    evaluation results.
    """

    def setUp(self):
        """Creates temp dirs, hyperparameters, dataset loader and metrics."""
        super(FairnessMetricsTest, self).setUp()
        # Imported locally so the class does not rely on a module-level import.
        import shutil  # pylint: disable=g-import-not-at-top

        self.num_thresholds = 5
        self.label_column_name = 'income'
        self.protected_groups = ['sex', 'race']
        self.subgroups = [0, 1, 2, 3]

        self.model_dir = tempfile.mkdtemp()
        self.print_dir = tempfile.mkdtemp()
        # Neither directory should outlive the test that created it.
        for temp_dir in (self.model_dir, self.print_dir):
            self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)

        self.primary_hidden_units = [16, 4]
        self.batch_size = 8
        self.train_steps = 10
        self.test_steps = 5
        self.pretrain_steps = 5

        test_dir = os.path.dirname(__file__)
        self.dataset_base_dir = os.path.join(test_dir, 'data/toy_data')
        self.train_file = [
            os.path.join(test_dir, 'data/toy_data/train.csv')
        ]
        self.test_file = [os.path.join(test_dir, 'data/toy_data/test.csv')]
        self.load_dataset = UCIAdultInput(
            dataset_base_dir=self.dataset_base_dir,
            train_file=self.train_file,
            test_file=self.test_file)
        self.fairness_metrics = RobustFairnessMetrics(
            label_column_name=self.label_column_name,
            protected_groups=self.protected_groups,
            subgroups=self.subgroups)
        self.eval_metric_keys = [
            'accuracy', 'recall', 'precision', 'tp', 'tn', 'fp', 'fn', 'fpr',
            'fnr'
        ]

    def _get_train_test_input_fn(self):
        """Returns the (train, eval) input functions of the toy dataset."""
        train_input_fn = self.load_dataset.get_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN, batch_size=self.batch_size)
        test_input_fn = self.load_dataset.get_input_fn(
            mode=tf.estimator.ModeKeys.EVAL, batch_size=self.batch_size)
        return train_input_fn, test_input_fn

    def _get_estimator(self):
        """Builds the adversarial reweighting estimator used by every test."""
        config = tf.estimator.RunConfig(model_dir=self.model_dir,
                                        save_checkpoints_steps=1)
        feature_columns, _, _, label_column_name = (
            self.load_dataset.get_feature_columns(
                include_sensitive_columns=True))
        return adversarial_reweighting_model.get_estimator(
            feature_columns=feature_columns,
            label_column_name=label_column_name,
            config=config,
            model_dir=self.model_dir,
            primary_hidden_units=self.primary_hidden_units,
            batch_size=self.batch_size,
            pretrain_steps=self.pretrain_steps,
            primary_learning_rate=0.01,
            adversary_learning_rate=0.01,
            optimizer='Adagrad',
            activation=tf.nn.relu,
            adversary_loss_type='ce_loss',
            adversary_include_label=True,
            upweight_positive_instance_only=False)

    def _train_and_evaluate_with_metrics(self, fairness_metrics=None):
        """Runs the pipeline shared by all tests and returns eval results.

        Args:
            fairness_metrics: The `RobustFairnessMetrics` instance whose
                metrics are attached. Defaults to `self.fairness_metrics`.

        Returns:
            The result of `estimator.evaluate` on the test input function.
        """
        if fairness_metrics is None:
            fairness_metrics = self.fairness_metrics

        # Instantiates a robust estimator.
        estimator = self._get_estimator()
        self.assertIsInstance(estimator, tf.estimator.Estimator)

        # Adds additional fairness metrics to the estimator.
        eval_metrics_fn = fairness_metrics.create_fairness_metrics_fn(
            num_thresholds=self.num_thresholds)
        estimator = tf.estimator.add_metrics(estimator, eval_metrics_fn)

        # Trains and evaluates the robust model.
        train_input_fn, test_input_fn = self._get_train_test_input_fn()
        estimator.train(input_fn=train_input_fn, steps=self.train_steps)
        return estimator.evaluate(input_fn=test_input_fn,
                                  steps=self.test_steps)

    def test_create_and_add_fairness_metrics(self):
        """Checks that the overall metric keys appear in eval results."""
        eval_results = self._train_and_evaluate_with_metrics()

        # Checks if eval_results are computed.
        self.assertNotEmpty(eval_results)
        for key in self.eval_metric_keys:
            self.assertIn(key, eval_results)

    def test_create_and_add_fairness_metrics_with_print_dir(self):
        """Checks the overall metric keys when `print_dir` is configured."""
        self.fairness_metrics_with_print = RobustFairnessMetrics(
            label_column_name=self.label_column_name,
            protected_groups=self.protected_groups,
            subgroups=self.subgroups,
            print_dir=self.print_dir)
        # The print_dir-enabled instance must be the one supplying the
        # metrics function; otherwise this test repeats the one above.
        eval_results = self._train_and_evaluate_with_metrics(
            self.fairness_metrics_with_print)

        # Checks if eval_results are computed.
        self.assertNotEmpty(eval_results)
        for key in self.eval_metric_keys:
            self.assertIn(key, eval_results)

    def test_subgroup_metrics(self):
        """Checks that auc, fpr and fnr are reported for every subgroup."""
        eval_results = self._train_and_evaluate_with_metrics()

        # Checks if eval_results are computed.
        self.assertNotEmpty(eval_results)

        # Checks if auc, fpr and fnr metrics are computed for all subgroups.
        for subgroup in self.subgroups:
            for metric in ('auc', 'fpr', 'fnr'):
                self.assertIn('{} subgroup {}'.format(metric, subgroup),
                              eval_results)

    def test_protected_group_metrics(self):
        """Checks that auc is reported for groups 0 and 1 of each column."""
        eval_results = self._train_and_evaluate_with_metrics()

        # Checks if eval_results are computed.
        self.assertNotEmpty(eval_results)

        # Checks if auc metric is computed for all protected_groups.
        for group in self.protected_groups:
            self.assertIn('auc {} group 0'.format(group), eval_results)
            self.assertIn('auc {} group 1'.format(group), eval_results)

    def test_threshold_metrics(self):
        """Checks the per-threshold confusion-matrix metrics."""
        eval_results = self._train_and_evaluate_with_metrics()

        # Checks if tp, tn, fp, fn metrics are computed at thresholds.
        for metric_key in ('fp_th', 'fn_th', 'tp_th', 'tn_th'):
            self.assertIn(metric_key, eval_results)

        # Checks if the length of tp_th matches self.num_thresholds.
        self.assertLen(eval_results['tp_th'], self.num_thresholds)

        # Checks if threshold metrics are computed for the first subgroup
        # and for both groups of the first protected column.
        self.assertIn('fp_th subgroup {}'.format(self.subgroups[0]),
                      eval_results)
        self.assertIn('fp_th {} group 0'.format(self.protected_groups[0]),
                      eval_results)
        self.assertIn('fp_th {} group 1'.format(self.protected_groups[0]),
                      eval_results)
class AdversarialSubgroupReweightingModelTest(tf.test.TestCase, absltest.TestCase):  # pylint: disable=line-too-long
    """Smoke tests for `adversarial_subgroup_reweighting_model.get_estimator`.

    Each test uses the toy CSV files under `data/toy_data` next to this
    module and a fresh temporary model directory that is deleted when the
    test ends.
    """

    def setUp(self):
        """Creates the temp model dir, hyperparameters, loader and metrics."""
        super(AdversarialSubgroupReweightingModelTest, self).setUp()
        # Imported locally so the class does not rely on a module-level import.
        import shutil  # pylint: disable=g-import-not-at-top

        self.model_dir = tempfile.mkdtemp()
        # Checkpoints written during a test must not outlive it.
        self.addCleanup(shutil.rmtree, self.model_dir, ignore_errors=True)

        self.primary_hidden_units = [16, 4]
        self.adversary_hidden_units = [4]
        self.batch_size = 8
        self.train_steps = 20
        self.test_steps = 5

        test_dir = os.path.dirname(__file__)
        self.dataset_base_dir = os.path.join(test_dir, 'data/toy_data')
        self.train_file = [
            os.path.join(test_dir, 'data/toy_data/train.csv')
        ]
        self.test_file = [os.path.join(test_dir, 'data/toy_data/test.csv')]
        self.load_dataset = UCIAdultInput(
            dataset_base_dir=self.dataset_base_dir,
            train_file=self.train_file,
            test_file=self.test_file)

        self.label_column_name = 'income'
        self.protected_groups = ['sex', 'race']
        self.subgroups = [0, 1, 2, 3]
        self.fairness_metrics = RobustFairnessMetrics(
            label_column_name=self.label_column_name,
            protected_groups=self.protected_groups,
            subgroups=self.subgroups)

    def _get_train_test_input_fn(self):
        """Returns the (train, eval) input functions of the toy dataset."""
        train_input_fn = self.load_dataset.get_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN, batch_size=self.batch_size)
        test_input_fn = self.load_dataset.get_input_fn(
            mode=tf.estimator.ModeKeys.EVAL, batch_size=self.batch_size)
        return train_input_fn, test_input_fn

    def _get_estimator_inputs(self, include_sensitive_columns):
        """Returns (config, feature_columns, label_column_name) for a test.

        Args:
            include_sensitive_columns: Forwarded to
                `UCIAdultInput.get_feature_columns`.
        """
        config = tf.estimator.RunConfig(model_dir=self.model_dir,
                                        save_checkpoints_steps=2)
        feature_columns, _, _, label_column_name = (
            self.load_dataset.get_feature_columns(
                include_sensitive_columns=include_sensitive_columns))
        return config, feature_columns, label_column_name

    def _build_estimator(self, include_sensitive_columns=True,
                         **estimator_kwargs):
        """Builds an estimator from the fixture settings.

        Args:
            include_sensitive_columns: Forwarded to
                `UCIAdultInput.get_feature_columns`.
            **estimator_kwargs: Extra keyword arguments forwarded unchanged
                to `adversarial_subgroup_reweighting_model.get_estimator`.

        Returns:
            Whatever `get_estimator` returns.
        """
        config, feature_columns, label_column_name = (
            self._get_estimator_inputs(include_sensitive_columns))
        return adversarial_subgroup_reweighting_model.get_estimator(
            feature_columns=feature_columns,
            label_column_name=label_column_name,
            protected_column_names=self.protected_groups,
            config=config,
            model_dir=self.model_dir,
            primary_hidden_units=self.primary_hidden_units,
            adversary_hidden_units=self.adversary_hidden_units,
            batch_size=self.batch_size,
            **estimator_kwargs)

    def _train_and_evaluate(self, estimator):
        """Trains for `train_steps`, then returns the evaluation results."""
        train_input_fn, test_input_fn = self._get_train_test_input_fn()
        estimator.train(input_fn=train_input_fn, steps=self.train_steps)
        return estimator.evaluate(input_fn=test_input_fn,
                                  steps=self.test_steps)

    def test_get_feature_columns_with_demographics(self):
        """Expects 14 feature columns when sensitive columns are included."""
        feature_columns, _, _, target_variable_column = (
            self.load_dataset.get_feature_columns(
                include_sensitive_columns=True))
        self.assertLen(feature_columns, 14)
        self.assertEqual(target_variable_column, self.label_column_name)

    def test_get_feature_columns_without_demographics(self):
        """Expects 12 feature columns when sensitive columns are excluded."""
        feature_columns, _, _, target_variable_column = (
            self.load_dataset.get_feature_columns(
                include_sensitive_columns=False))
        self.assertLen(feature_columns, 12)
        self.assertEqual(target_variable_column, self.label_column_name)

    def test_get_input_fn(self):
        """Checks the keys of the targets and the number of features."""
        input_fn = self.load_dataset.get_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN, batch_size=self.batch_size)
        features, targets = input_fn()
        for expected_key in ('sex', 'race', 'subgroup',
                             self.label_column_name):
            self.assertIn(expected_key, targets)
        self.assertLen(features, 15)

    def test_no_protected_columns_added_in_features(self):
        """Should raise ValueError if protected columns are not in features."""
        # Setup stays outside the assertRaises block so that only an error
        # raised by get_estimator itself can satisfy the test.
        config, feature_columns, label_column_name = (
            self._get_estimator_inputs(include_sensitive_columns=False))
        with self.assertRaises(ValueError):
            _ = adversarial_subgroup_reweighting_model.get_estimator(
                feature_columns=feature_columns,
                label_column_name=label_column_name,
                protected_column_names=self.protected_groups,
                config=config,
                model_dir=self.model_dir,
                primary_hidden_units=self.primary_hidden_units,
                adversary_hidden_units=self.adversary_hidden_units,
                batch_size=self.batch_size,
                primary_learning_rate=0.01,
                optimizer='Adagrad',
                activation=tf.nn.relu)

    def test_eval_results_model(self):
        """Checks that evaluation reports auc and confusion-matrix counts."""
        estimator = self._build_estimator()
        self.assertIsInstance(estimator, tf.estimator.Estimator)
        eval_results = self._train_and_evaluate(estimator)
        self.assertNotEmpty(eval_results)
        # Checks if auc and all tp, tn, fp, fn keys are present in
        # eval_results.
        for metric_key in ('auc', 'fp', 'fn', 'tp', 'tn'):
            self.assertIn(metric_key, eval_results)

    def test_add_fairness_metrics_model(self):
        """Checks that per-subgroup auc is reported after adding metrics."""
        estimator = self._build_estimator()
        self.assertIsInstance(estimator, tf.estimator.Estimator)

        # Adds additional fairness metrics to the estimator.
        eval_metrics_fn = self.fairness_metrics.create_fairness_metrics_fn()
        estimator = tf.estimator.add_metrics(estimator, eval_metrics_fn)

        eval_results = self._train_and_evaluate(estimator)
        self.assertNotEmpty(eval_results)
        # Checks if auc metric is computed for all subgroups.
        for subgroup in self.subgroups:
            self.assertIn('auc subgroup {}'.format(subgroup), eval_results)

    def test_global_steps_model(self):
        """Checks the global step reached after training."""
        estimator = self._build_estimator()
        self.assertIsInstance(estimator, tf.estimator.Estimator)
        eval_results = self._train_and_evaluate(estimator)
        # Checks if global step has reached the specified number of
        # train_steps. As an artifact of the way the alternate train_ops is
        # defined, training stops two steps after the specified number of
        # train_steps.
        self.assertIn('global_step', eval_results)
        self.assertEqual(eval_results['global_step'], self.train_steps + 2)

    def test_create_estimator_with_demographics(self):
        """Creates an estimator with sensitive columns and explicit settings."""
        estimator = self._build_estimator(
            primary_learning_rate=0.01,
            optimizer='Adagrad',
            activation=tf.nn.relu)
        self.assertIsInstance(estimator, tf.estimator.Estimator)