def testCustomBinningScore(self):
        y_true = np.array([1., 0., 0., 1.])
        y_pred = np.array([0.31, 0.32, 0.83, 0.64])

        oracle_auc = metrics.OracleCollaborativeAUC(
            oracle_fraction=0.5,  # 2 examples sent to oracle
            num_bins=4,  # (-inf, 0.25), [0.25, 0.5), [0.5, 0.75), [0.75, inf)
            num_thresholds=4,  # -1e-7, 0.33, 0.67, 1.0000001
        )

        # This custom_binning_score means 0.31 and 0.32 are always sent to oracle.
        result = oracle_auc(y_true, y_pred, custom_binning_score=y_pred)

        self.assertAllClose(
            oracle_auc.binned_true_positives,
            # y_true's positives are 0.31 and 0.64 in y_pred.
            np.array([
                [0., 1., 1., 0.],
                [0., 0., 1., 0.],  # 0.31 is no longer above threshold 0.33
                [0., 0., 0., 0.],  # 0.64 is below threshold 0.67
                [0., 0., 0., 0.],
            ]))
        self.assertAllClose(
            oracle_auc.binned_true_negatives,
            # The possible true negatives are 0.32 and 0.83.
            np.array([
                [0., 0., 0., 0.],
                [0., 1., 0., 0.],  # 0.32 is below threshold 0.33
                [0., 1., 0., 0.],  # 0.83 is still above threshold 0.67
                [0., 1., 0., 1.],
            ]))
        self.assertAllClose(
            oracle_auc.binned_false_positives,
            # Compare these values with oracle_auc.binned_true_negatives.
            # For example, the total across their rows must always be 2.
            np.array([
                [0., 1., 0.,
                 1.],  # 0.32 and 0.83 are both above threshold -1e-7
                [0., 0., 0., 1.],  # 0.32 moves to true_negatives
                [0., 0., 0., 1.],  # 0.83 still above threshold
                [0., 0., 0., 0.],  # all examples moved to true_negatives
            ]))
        self.assertAllClose(
            oracle_auc.binned_false_negatives,
            # Compare these values with oracle_auc.binned_true_positives.
            np.array([
                [0., 0., 0., 0.],
                [0., 1., 0.,
                 0.],  # 0.31 becomes a false negative at threshold 0.33
                [0., 1., 1.,
                 0.],  # 0.64 becomes a false negative at threshold 0.67
                [0., 1., 1., 0.],
            ]))

        # At the two interior thresholds, 0.31 is corrected from
        # false_negatives to true_positives; the first and last thresholds lie
        # outside [0, 1] and are never corrected.
        self.assertAllClose(oracle_auc.true_positives,
                            np.array([2., 2., 1., 0.]))
        self.assertAllClose(oracle_auc.true_negatives,
                            np.array([0., 1., 1., 2.]))
        self.assertAllClose(oracle_auc.false_positives,
                            np.array([2., 1., 1., 0.]))
        self.assertAllClose(oracle_auc.false_negatives,
                            np.array([0., 0., 1., 2.]))
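
        # Hand check (our own arithmetic, assuming the default trapezoidal
        # summation of tf.keras.metrics.AUC): the corrected counts give
        #   TPR = TP / (TP + FN) = [1.0, 1.0, 0.5, 0.0]
        #   FPR = FP / (FP + TN) = [1.0, 0.5, 0.5, 0.0]
        # whose trapezoidal area is 0.5 * 1.0 + 0.0 + 0.5 * 0.25 = 0.625,
        # matching the value asserted below.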

        self.assertEqual(result, 0.625)
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Model checkpoint will be saved at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    test_batch_size = batch_size
    data_buffer_size = batch_size * 10

    train_dataset_builder = ds.WikipediaToxicityDataset(
        split='train',
        data_dir=FLAGS.in_dataset_dir,
        shuffle_buffer_size=data_buffer_size)
    ind_dataset_builder = ds.WikipediaToxicityDataset(
        split='test',
        data_dir=FLAGS.in_dataset_dir,
        shuffle_buffer_size=data_buffer_size)
    ood_dataset_builder = ds.CivilCommentsDataset(
        split='test',
        data_dir=FLAGS.ood_dataset_dir,
        shuffle_buffer_size=data_buffer_size)
    ood_identity_dataset_builder = ds.CivilCommentsIdentitiesDataset(
        split='test',
        data_dir=FLAGS.identity_dataset_dir,
        shuffle_buffer_size=data_buffer_size)

    train_dataset_builders = {
        'wikipedia_toxicity_subtypes': train_dataset_builder
    }
    test_dataset_builders = {
        'ind': ind_dataset_builder,
        'ood': ood_dataset_builder,
        'ood_identity': ood_identity_dataset_builder,
    }
    if FLAGS.prediction_mode and FLAGS.identity_prediction:
        for dataset_name in utils.IDENTITY_LABELS:
            if utils.NUM_EXAMPLES[dataset_name]['test'] > 100:
                test_dataset_builders[
                    dataset_name] = ds.CivilCommentsIdentitiesDataset(
                        split='test',
                        data_dir=os.path.join(
                            FLAGS.identity_specific_dataset_dir, dataset_name),
                        shuffle_buffer_size=data_buffer_size)
        for dataset_name in utils.IDENTITY_TYPES:
            if utils.NUM_EXAMPLES[dataset_name]['test'] > 100:
                test_dataset_builders[
                    dataset_name] = ds.CivilCommentsIdentitiesDataset(
                        split='test',
                        data_dir=os.path.join(FLAGS.identity_type_dataset_dir,
                                              dataset_name),
                        shuffle_buffer_size=data_buffer_size)

    class_weight = utils.create_class_weight(train_dataset_builders,
                                             test_dataset_builders)
    logging.info('class_weight: %s', str(class_weight))

    ds_info = train_dataset_builder.tfds_info
    # Positive and negative classes.
    num_classes = ds_info.metadata['num_classes']

    train_datasets = {}
    dataset_steps_per_epoch = {}
    total_steps_per_epoch = 0

    # TODO(jereliu): Apply strategy.experimental_distribute_dataset to the
    # dataset_builders.
    for dataset_name, dataset_builder in train_dataset_builders.items():
        train_datasets[dataset_name] = dataset_builder.load(
            batch_size=FLAGS.per_core_batch_size)
        dataset_steps_per_epoch[dataset_name] = (
            dataset_builder.num_examples // batch_size)
        total_steps_per_epoch += dataset_steps_per_epoch[dataset_name]

    test_datasets = {}
    steps_per_eval = {}
    for dataset_name, dataset_builder in test_dataset_builders.items():
        test_datasets[dataset_name] = dataset_builder.load(
            batch_size=test_batch_size)
        if dataset_name in ['ind', 'ood', 'ood_identity']:
            steps_per_eval[dataset_name] = (dataset_builder.num_examples //
                                            test_batch_size)
        else:
            steps_per_eval[dataset_name] = (
                utils.NUM_EXAMPLES[dataset_name]['test'] // test_batch_size)

    if FLAGS.use_bfloat16:
        tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building %s model', FLAGS.model_family)

        bert_config_dir, bert_ckpt_dir = utils.resolve_bert_ckpt_and_config_dir(
            FLAGS.bert_model_type, FLAGS.bert_dir, FLAGS.bert_config_dir,
            FLAGS.bert_ckpt_dir)
        bert_config = utils.create_config(bert_config_dir)
        bert_config.hidden_dropout_prob = FLAGS.dropout_rate
        bert_config.attention_probs_dropout_prob = FLAGS.dropout_rate
        model, bert_encoder = ub.models.bert_dropout_model(
            num_classes=num_classes,
            bert_config=bert_config,
            use_mc_dropout_mha=FLAGS.use_mc_dropout_mha,
            use_mc_dropout_att=FLAGS.use_mc_dropout_att,
            use_mc_dropout_ffn=FLAGS.use_mc_dropout_ffn,
            use_mc_dropout_output=FLAGS.use_mc_dropout_output,
            channel_wise_dropout_mha=FLAGS.channel_wise_dropout_mha,
            channel_wise_dropout_att=FLAGS.channel_wise_dropout_att,
            channel_wise_dropout_ffn=FLAGS.channel_wise_dropout_ffn)

        # Create an AdamW optimizer with beta_2=0.999, epsilon=1e-6.
        optimizer = utils.create_optimizer(
            FLAGS.base_learning_rate,
            steps_per_epoch=total_steps_per_epoch,
            epochs=FLAGS.train_epochs,
            warmup_proportion=FLAGS.warmup_proportion,
            beta_1=1.0 - FLAGS.one_minus_momentum)

        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())

        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.Accuracy(),
            'train/accuracy_weighted':
            tf.keras.metrics.Accuracy(),
            'train/auroc':
            tf.keras.metrics.AUC(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_ece_bins),
            'train/precision':
            tf.keras.metrics.Precision(),
            'train/recall':
            tf.keras.metrics.Recall(),
            'train/f1':
            tfa_metrics.F1Score(num_classes=num_classes,
                                average='micro',
                                threshold=FLAGS.ece_label_threshold),
        }

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        if FLAGS.prediction_mode:
            latest_checkpoint = tf.train.latest_checkpoint(
                FLAGS.eval_checkpoint_dir)
        else:
            latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy(
            ) // total_steps_per_epoch
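            # optimizer.iterations counts completed training steps, so this
            # resumes from the epoch at which the checkpoint was saved.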
        elif FLAGS.model_family.lower() == 'bert':
            # load BERT from initial checkpoint
            bert_checkpoint = tf.train.Checkpoint(model=bert_encoder)
            bert_checkpoint.restore(
                bert_ckpt_dir).assert_existing_objects_matched()
            logging.info('Loaded BERT checkpoint %s', bert_ckpt_dir)

        metrics.update({
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/auroc':
            tf.keras.metrics.AUC(curve='ROC'),
            'test/aupr':
            tf.keras.metrics.AUC(curve='PR'),
            'test/brier':
            tf.keras.metrics.MeanSquaredError(),
            'test/brier_weighted':
            tf.keras.metrics.MeanSquaredError(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_ece_bins),
            'test/acc':
            tf.keras.metrics.Accuracy(),
            'test/acc_weighted':
            tf.keras.metrics.Accuracy(),
            'test/eval_time':
            tf.keras.metrics.Mean(),
            'test/precision':
            tf.keras.metrics.Precision(),
            'test/recall':
            tf.keras.metrics.Recall(),
            'test/f1':
            tfa_metrics.F1Score(num_classes=num_classes,
                                average='micro',
                                threshold=FLAGS.ece_label_threshold)
        })
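
        # Note: the metric keys below are built with str.format. For a
        # hypothetical FLAGS.fractions entry '0.05', policy 'uncertainty', and
        # dataset 'ood', the generated keys include
        # 'test_uncertainty/collab_auroc_0.05' and
        # 'test_uncertainty/collab_auroc_0.05_ood'.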

        for policy in ('uncertainty', 'toxicity'):
            metrics.update({
                'test_{}/calibration_auroc'.format(policy):
                tc_metrics.CalibrationAUC(curve='ROC'),
                'test_{}/calibration_auprc'.format(policy):
                tc_metrics.CalibrationAUC(curve='PR')
            })

            for fraction in FLAGS.fractions:
                metrics.update({
                    'test_{}/collab_acc_{}'.format(policy, fraction):
                    rm.metrics.OracleCollaborativeAccuracy(
                        fraction=float(fraction),
                        num_bins=FLAGS.num_approx_bins),
                    'test_{}/abstain_prec_{}'.format(policy, fraction):
                    tc_metrics.AbstainPrecision(
                        abstain_fraction=float(fraction),
                        num_approx_bins=FLAGS.num_approx_bins),
                    'test_{}/abstain_recall_{}'.format(policy, fraction):
                    tc_metrics.AbstainRecall(
                        abstain_fraction=float(fraction),
                        num_approx_bins=FLAGS.num_approx_bins),
                    'test_{}/collab_auroc_{}'.format(policy, fraction):
                    tc_metrics.OracleCollaborativeAUC(
                        oracle_fraction=float(fraction),
                        num_bins=FLAGS.num_approx_bins),
                    'test_{}/collab_auprc_{}'.format(policy, fraction):
                    tc_metrics.OracleCollaborativeAUC(
                        oracle_fraction=float(fraction),
                        curve='PR',
                        num_bins=FLAGS.num_approx_bins),
                })

        for dataset_name, test_dataset in test_datasets.items():
            if dataset_name != 'ind':
                metrics.update({
                    'test/nll_{}'.format(dataset_name):
                    tf.keras.metrics.Mean(),
                    'test/auroc_{}'.format(dataset_name):
                    tf.keras.metrics.AUC(curve='ROC'),
                    'test/aupr_{}'.format(dataset_name):
                    tf.keras.metrics.AUC(curve='PR'),
                    'test/brier_{}'.format(dataset_name):
                    tf.keras.metrics.MeanSquaredError(),
                    'test/brier_weighted_{}'.format(dataset_name):
                    tf.keras.metrics.MeanSquaredError(),
                    'test/ece_{}'.format(dataset_name):
                    rm.metrics.ExpectedCalibrationError(
                        num_bins=FLAGS.num_ece_bins),
                    'test/acc_{}'.format(dataset_name):
                    tf.keras.metrics.Accuracy(),
                    'test/acc_weighted_{}'.format(dataset_name):
                    tf.keras.metrics.Accuracy(),
                    'test/eval_time_{}'.format(dataset_name):
                    tf.keras.metrics.Mean(),
                    'test/precision_{}'.format(dataset_name):
                    tf.keras.metrics.Precision(),
                    'test/recall_{}'.format(dataset_name):
                    tf.keras.metrics.Recall(),
                    'test/f1_{}'.format(dataset_name):
                    tfa_metrics.F1Score(num_classes=num_classes,
                                        average='micro',
                                        threshold=FLAGS.ece_label_threshold)
                })

                for policy in ('uncertainty', 'toxicity'):
                    metrics.update({
                        'test_{}/calibration_auroc_{}'.format(
                            policy, dataset_name):
                        tc_metrics.CalibrationAUC(curve='ROC'),
                        'test_{}/calibration_auprc_{}'.format(
                            policy, dataset_name):
                        tc_metrics.CalibrationAUC(curve='PR'),
                    })

                    for fraction in FLAGS.fractions:
                        metrics.update({
                            'test_{}/collab_acc_{}_{}'.format(
                                policy, fraction, dataset_name):
                            rm.metrics.OracleCollaborativeAccuracy(
                                fraction=float(fraction),
                                num_bins=FLAGS.num_approx_bins),
                            'test_{}/abstain_prec_{}_{}'.format(
                                policy, fraction, dataset_name):
                            tc_metrics.AbstainPrecision(
                                abstain_fraction=float(fraction),
                                num_approx_bins=FLAGS.num_approx_bins),
                            'test_{}/abstain_recall_{}_{}'.format(
                                policy, fraction, dataset_name):
                            tc_metrics.AbstainRecall(
                                abstain_fraction=float(fraction),
                                num_approx_bins=FLAGS.num_approx_bins),
                            'test_{}/collab_auroc_{}_{}'.format(
                                policy, fraction, dataset_name):
                            tc_metrics.OracleCollaborativeAUC(
                                oracle_fraction=float(fraction),
                                num_bins=FLAGS.num_approx_bins),
                            'test_{}/collab_auprc_{}_{}'.format(
                                policy, fraction, dataset_name):
                            tc_metrics.OracleCollaborativeAUC(
                                oracle_fraction=float(fraction),
                                curve='PR',
                                num_bins=FLAGS.num_approx_bins),
                        })

    @tf.function
    def generate_sample_weight(labels, class_weight, label_threshold=0.7):
        """Generate sample weight for weighted accuracy calculation."""
        if label_threshold != 0.7:
            logging.warning(
                'The class weight was based on `label_threshold` = 0.7, '
                'and weighted accuracy/brier will be meaningless if '
                '`label_threshold` is not equal to this value, which is '
                'recommended by the Jigsaw Conversation AI team.')
        labels_int = tf.cast(labels > label_threshold, tf.int32)
        sample_weight = tf.gather(class_weight, labels_int)
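        # For illustration (hypothetical values): labels = [0.2, 0.9] with
        # class_weight = [0.5, 2.0] gives labels_int = [0, 1] and
        # sample_weight = [0.5, 2.0].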
        return sample_weight

    @tf.function
    def train_step(iterator, dataset_name, num_steps):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            features, labels, _ = utils.create_feature_and_label(inputs)

            with tf.GradientTape() as tape:
                logits = model(features, training=True)

                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                loss_logits = tf.squeeze(logits, axis=1)
                if FLAGS.loss_type == 'cross_entropy':
                    logging.info('Using cross entropy loss')
                    negative_log_likelihood = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels, loss_logits)
                elif FLAGS.loss_type == 'focal_cross_entropy':
                    logging.info('Using focal cross entropy loss')
                    negative_log_likelihood = tfa_losses.sigmoid_focal_crossentropy(
                        labels,
                        loss_logits,
                        alpha=FLAGS.focal_loss_alpha,
                        gamma=FLAGS.focal_loss_gamma,
                        from_logits=True)
                elif FLAGS.loss_type == 'mse':
                    logging.info('Using mean squared error loss')
                    loss_probs = tf.nn.sigmoid(loss_logits)
                    negative_log_likelihood = tf.keras.losses.mean_squared_error(
                        labels, loss_probs)
                elif FLAGS.loss_type == 'mae':
                    logging.info('Using mean absolute error loss')
                    loss_probs = tf.nn.sigmoid(loss_logits)
                    negative_log_likelihood = tf.keras.losses.mean_absolute_error(
                        labels, loss_probs)

                negative_log_likelihood = tf.reduce_mean(
                    negative_log_likelihood)

                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, since the TPUStrategy sum-reduces gradients
                # across all replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.sigmoid(logits)
            # Cast labels to discrete for ECE computation.
            ece_labels = tf.cast(labels > FLAGS.ece_label_threshold,
                                 tf.float32)
            one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32),
                                        depth=num_classes)
            ece_probs = tf.concat([1. - probs, probs], axis=1)
            auc_probs = tf.squeeze(probs, axis=1)
            pred_labels = tf.math.argmax(ece_probs, axis=-1)
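            # For a single example with probs = [[0.8]]: ece_probs is
            # [[0.2, 0.8]], pred_labels is [1], and auc_probs is [0.8].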

            sample_weight = generate_sample_weight(
                labels, class_weight['train/{}'.format(dataset_name)],
                FLAGS.ece_label_threshold)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, pred_labels)
            metrics['train/accuracy_weighted'].update_state(
                ece_labels, pred_labels, sample_weight=sample_weight)
            metrics['train/auroc'].update_state(labels, auc_probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/ece'].add_batch(ece_probs, label=ece_labels)
            metrics['train/precision'].update_state(ece_labels, pred_labels)
            metrics['train/recall'].update_state(ece_labels, pred_labels)
            metrics['train/f1'].update_state(one_hot_labels, ece_probs)

        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn to log metrics."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            features, labels, _ = utils.create_feature_and_label(inputs)

            eval_start_time = time.time()
            logits = model(features, training=False)
            eval_time = (time.time() -
                         eval_start_time) / FLAGS.per_core_batch_size

            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.sigmoid(logits)
            # Cast labels to discrete for ECE computation.
            ece_labels = tf.cast(labels > FLAGS.ece_label_threshold,
                                 tf.float32)
            one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32),
                                        depth=num_classes)
            ece_probs = tf.concat([1. - probs, probs], axis=1)
            pred_labels = tf.math.argmax(ece_probs, axis=-1)
            auc_probs = tf.squeeze(probs, axis=1)

            loss_logits = tf.squeeze(logits, axis=1)
            negative_log_likelihood = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels, loss_logits))

            # Use normalized binary predictive variance as the confidence score.
            # Since the prediction variance p*(1-p) is within range (0, 0.25),
            # normalize it by maximum value so the confidence is between (0, 1).
            calib_confidence = 1. - probs * (1. - probs) / .25
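            # E.g. probs = 0.5 gives calib_confidence = 0., while probs = 0.99
            # gives roughly 0.96.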

            sample_weight = generate_sample_weight(
                labels, class_weight['test/{}'.format(dataset_name)],
                FLAGS.ece_label_threshold)
            if dataset_name == 'ind':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/auroc'].update_state(labels, auc_probs)
                metrics['test/aupr'].update_state(labels, auc_probs)
                metrics['test/brier'].update_state(labels, auc_probs)
                metrics['test/brier_weighted'].update_state(
                    tf.expand_dims(labels, -1),
                    probs,
                    sample_weight=sample_weight)
                metrics['test/ece'].add_batch(ece_probs, label=ece_labels)
                metrics['test/acc'].update_state(ece_labels, pred_labels)
                metrics['test/acc_weighted'].update_state(
                    ece_labels, pred_labels, sample_weight=sample_weight)
                metrics['test/eval_time'].update_state(eval_time)
                metrics['test/precision'].update_state(ece_labels, pred_labels)
                metrics['test/recall'].update_state(ece_labels, pred_labels)
                metrics['test/f1'].update_state(one_hot_labels, ece_probs)

                for policy in ('uncertainty', 'toxicity'):
                    # calib_confidence or decreasing toxicity score.
                    confidence = 1. - probs if policy == 'toxicity' else calib_confidence
                    binning_confidence = tf.squeeze(confidence)

                    metrics['test_{}/calibration_auroc'.format(
                        policy)].update_state(ece_labels, pred_labels,
                                              confidence)
                    metrics['test_{}/calibration_auprc'.format(
                        policy)].update_state(ece_labels, pred_labels,
                                              confidence)

                    for fraction in FLAGS.fractions:
                        metrics['test_{}/collab_acc_{}'.format(
                            policy, fraction)].add_batch(
                                ece_probs,
                                label=ece_labels,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/abstain_prec_{}'.format(
                            policy,
                            fraction)].update_state(ece_labels, pred_labels,
                                                    confidence)
                        metrics['test_{}/abstain_recall_{}'.format(
                            policy,
                            fraction)].update_state(ece_labels, pred_labels,
                                                    confidence)
                        metrics['test_{}/collab_auroc_{}'.format(
                            policy, fraction)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/collab_auprc_{}'.format(
                            policy, fraction)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)

            else:
                metrics['test/nll_{}'.format(dataset_name)].update_state(
                    negative_log_likelihood)
                metrics['test/auroc_{}'.format(dataset_name)].update_state(
                    labels, auc_probs)
                metrics['test/aupr_{}'.format(dataset_name)].update_state(
                    labels, auc_probs)
                metrics['test/brier_{}'.format(dataset_name)].update_state(
                    labels, auc_probs)
                metrics['test/brier_weighted_{}'.format(
                    dataset_name)].update_state(tf.expand_dims(labels, -1),
                                                probs,
                                                sample_weight=sample_weight)
                metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    ece_probs, label=ece_labels)
                metrics['test/acc_{}'.format(dataset_name)].update_state(
                    ece_labels, pred_labels)
                metrics['test/acc_weighted_{}'.format(
                    dataset_name)].update_state(ece_labels,
                                                pred_labels,
                                                sample_weight=sample_weight)
                metrics['test/eval_time_{}'.format(dataset_name)].update_state(
                    eval_time)
                metrics['test/precision_{}'.format(dataset_name)].update_state(
                    ece_labels, pred_labels)
                metrics['test/recall_{}'.format(dataset_name)].update_state(
                    ece_labels, pred_labels)
                metrics['test/f1_{}'.format(dataset_name)].update_state(
                    one_hot_labels, ece_probs)

                for policy in ('uncertainty', 'toxicity'):
                    # calib_confidence or decreasing toxicity score.
                    confidence = 1. - probs if policy == 'toxicity' else calib_confidence
                    binning_confidence = tf.squeeze(confidence)

                    metrics['test_{}/calibration_auroc_{}'.format(
                        policy,
                        dataset_name)].update_state(ece_labels, pred_labels,
                                                    confidence)
                    metrics['test_{}/calibration_auprc_{}'.format(
                        policy,
                        dataset_name)].update_state(ece_labels, pred_labels,
                                                    confidence)

                    for fraction in FLAGS.fractions:
                        metrics['test_{}/collab_acc_{}_{}'.format(
                            policy, fraction, dataset_name)].add_batch(
                                ece_probs,
                                label=ece_labels,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/abstain_prec_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                ece_labels, pred_labels, confidence)
                        metrics['test_{}/abstain_recall_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                ece_labels, pred_labels, confidence)
                        metrics['test_{}/collab_auroc_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/collab_auprc_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def final_eval_step(iterator):
        """Final Evaluation StepFn to save prediction to directory."""
        def step_fn(inputs):
            bert_features, labels, additional_labels = utils.create_feature_and_label(
                inputs)
            logits = model(bert_features, training=False)
            features = inputs['input_ids']
            return features, logits, labels, additional_labels

        (per_replica_texts, per_replica_logits, per_replica_labels,
         per_replica_additional_labels) = (strategy.run(
             step_fn, args=(next(iterator), )))
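        # With more than one replica, strategy.run returns PerReplica objects
        # whose per-replica components (.values) are concatenated below; with
        # a single replica the outputs are already plain tensors.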

        if strategy.num_replicas_in_sync > 1:
            texts_list = tf.concat(per_replica_texts.values, axis=0)
            logits_list = tf.concat(per_replica_logits.values, axis=0)
            labels_list = tf.concat(per_replica_labels.values, axis=0)
            additional_labels_dict = {}
            for additional_label in utils.IDENTITY_LABELS:
                if additional_label in per_replica_additional_labels:
                    additional_labels_dict[additional_label] = tf.concat(
                        per_replica_additional_labels[additional_label],
                        axis=0)
        else:
            texts_list = per_replica_texts
            logits_list = per_replica_logits
            labels_list = per_replica_labels
            additional_labels_dict = {}
            for additional_label in utils.IDENTITY_LABELS:
                if additional_label in per_replica_additional_labels:
                    additional_labels_dict[
                        additional_label] = per_replica_additional_labels[
                            additional_label]

        return texts_list, logits_list, labels_list, additional_labels_dict

    if FLAGS.prediction_mode:
        # Prediction and exit.
        for dataset_name, test_dataset in test_datasets.items():
            test_iterator = iter(test_dataset)  # pytype: disable=wrong-arg-types
            message = 'Final eval on dataset {}'.format(dataset_name)
            logging.info(message)

            texts_all = []
            logits_all = []
            labels_all = []
            additional_labels_all_dict = {}
            if 'identity' in dataset_name:
                for identity_label_name in utils.IDENTITY_LABELS:
                    additional_labels_all_dict[identity_label_name] = []

            try:
                with tf.experimental.async_scope():
                    for step in range(steps_per_eval[dataset_name]):
                        if step % 20 == 0:
                            message = 'Starting to run eval step {}/{} of dataset: {}'.format(
                                step, steps_per_eval[dataset_name],
                                dataset_name)
                            logging.info(message)

                        (text_step, logits_step, labels_step,
                         additional_labels_dict_step
                         ) = final_eval_step(test_iterator)

                        texts_all.append(text_step)
                        logits_all.append(logits_step)
                        labels_all.append(labels_step)
                        if 'identity' in dataset_name:
                            for identity_label_name in utils.IDENTITY_LABELS:
                                additional_labels_all_dict[
                                    identity_label_name].append(
                                        additional_labels_dict_step[
                                            identity_label_name])

            except (StopIteration, tf.errors.OutOfRangeError):
                tf.experimental.async_clear_error()
                logging.info('Done with eval on %s', dataset_name)

            texts_all = tf.concat(texts_all, axis=0)
            logits_all = tf.concat(logits_all, axis=0)
            labels_all = tf.concat(labels_all, axis=0)
            additional_labels_all = []
            if additional_labels_all_dict:
                for identity_label_name in utils.IDENTITY_LABELS:
                    additional_labels_all.append(
                        tf.concat(
                            additional_labels_all_dict[identity_label_name],
                            axis=0))
            additional_labels_all = tf.convert_to_tensor(additional_labels_all)

            utils.save_prediction(texts_all.numpy(),
                                  path=os.path.join(
                                      FLAGS.output_dir,
                                      'texts_{}'.format(dataset_name)))
            utils.save_prediction(labels_all.numpy(),
                                  path=os.path.join(
                                      FLAGS.output_dir,
                                      'labels_{}'.format(dataset_name)))
            utils.save_prediction(logits_all.numpy(),
                                  path=os.path.join(
                                      FLAGS.output_dir,
                                      'logits_{}'.format(dataset_name)))
            if 'identity' in dataset_name:
                utils.save_prediction(
                    additional_labels_all.numpy(),
                    path=os.path.join(
                        FLAGS.output_dir,
                        'additional_labels_{}'.format(dataset_name)))
            logging.info('Done with testing on %s', dataset_name)

    else:
        # Execute train / eval loop.
        start_time = time.time()
        train_iterators = {}
        for dataset_name, train_dataset in train_datasets.items():
            train_iterators[dataset_name] = iter(train_dataset)
        for epoch in range(initial_epoch, FLAGS.train_epochs):
            logging.info('Starting to run epoch: %s', epoch)
            for dataset_name, train_iterator in train_iterators.items():
                train_step(train_iterator, dataset_name,
                           dataset_steps_per_epoch[dataset_name])

                current_step = (epoch * total_steps_per_epoch +
                                dataset_steps_per_epoch[dataset_name])
                max_steps = total_steps_per_epoch * FLAGS.train_epochs
                time_elapsed = time.time() - start_time
                steps_per_sec = float(current_step) / time_elapsed
                eta_seconds = (max_steps - current_step) / steps_per_sec
                message = (
                    '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                    'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                        current_step / max_steps, epoch + 1,
                        FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                        time_elapsed / 60))
                logging.info(message)

            if epoch % FLAGS.evaluation_interval == 0:
                for dataset_name, test_dataset in test_datasets.items():
                    test_iterator = iter(test_dataset)  # pytype: disable=wrong-arg-types
                    logging.info('Testing on dataset %s', dataset_name)

                    try:
                        with tf.experimental.async_scope():
                            for step in range(steps_per_eval[dataset_name]):
                                if step % 20 == 0:
                                    logging.info(
                                        'Starting to run eval step %s/%s of epoch: %s',
                                        step, steps_per_eval[dataset_name],
                                        epoch)
                                test_step(test_iterator, dataset_name)
                    except (StopIteration, tf.errors.OutOfRangeError):
                        tf.experimental.async_clear_error()
                        logging.info('Done with testing on %s', dataset_name)

                logging.info('Train Loss: %.4f, AUROC: %.4f',
                             metrics['train/loss'].result(),
                             metrics['train/auroc'].result())
                logging.info('Test NLL: %.4f, AUROC: %.4f',
                             metrics['test/negative_log_likelihood'].result(),
                             metrics['test/auroc'].result())

                # record results
                total_results = {
                    name: metric.result()
                    for name, metric in metrics.items()
                }
                # Metrics from Robustness Metrics (like ECE) will return a dict with a
                # single key/value, instead of a scalar.
                total_results = {
                    k: (list(v.values())[0] if isinstance(v, dict) else v)
                    for k, v in total_results.items()
                }

                with summary_writer.as_default():
                    for name, result in total_results.items():
                        tf.summary.scalar(name, result, step=epoch + 1)

            for name, metric in metrics.items():
                metric.reset_states()

            checkpoint_interval = min(FLAGS.checkpoint_interval,
                                      FLAGS.train_epochs)
            if checkpoint_interval > 0 and (epoch +
                                            1) % checkpoint_interval == 0:
                checkpoint_name = checkpoint.save(
                    os.path.join(FLAGS.output_dir, 'checkpoint'))
                logging.info('Saved checkpoint to %s', checkpoint_name)

        # Save model in SavedModel format on exit.
        final_save_name = os.path.join(FLAGS.output_dir, 'model')
        model.save(final_save_name)
        logging.info('Saved model to %s', final_save_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
        })
    def testPROracleFractionTwoThirds(self):
        y_true = np.array([0., 0., 1., 1., 0., 1., 1., 0.])
        y_pred = np.array([0.31, 0.33, 0.42, 0.58, 0.69, 0.76, 0.84, 0.87])

        num_thresholds = 5  # -1e-7, 0.25, 0.5, 0.75, 1.0000001
        num_bins = 3
        curve = 'PR'
        oracle_auc = metrics.OracleCollaborativeAUC(
            oracle_fraction=0.67,  # floor(0.67 * 8) = 5 examples sent to oracle
            num_thresholds=num_thresholds,
            num_bins=num_bins,
            curve=curve)

        result = oracle_auc(y_true, y_pred)
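        # No custom_binning_score is passed here; judging by the expected bin
        # counts below, examples are binned by their distance from the current
        # threshold, so the oracle reviews the predictions closest to it.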
        self.assertAllClose(
            oracle_auc.binned_true_positives,
            # y_true's positives are 0.42, 0.58, 0.76, and 0.84 in y_pred.
            np.array([
                [0., 2., 2.],  # Threshold -1e-7; bins are unmodified
                [2., 2., 0.],  # Threshold 0.25; bins [0, 0.58), [0.58, 0.91)
                [2., 1., 0.],  # Threshold 0.5: 0.42 is now a false negative.
                [2., 0.,
                 0.],  # Threshold 0.75: only 0.76 and 0.84 are positive.
                [0., 0., 0.],  # Threshold 1.0000001: no positives.
            ]))
        self.assertAllClose(
            oracle_auc.binned_true_negatives,
            # The possible true negatives are 0.31, 0.33, 0.69, and 0.87.
            np.array([
                [0., 0., 0.],  # There are no negatives for threshold -1e-7.
                [0., 0., 0.],  # Threshold 0.25: still no negatives.
                [2., 0., 0.],  # Threshold 0.5: 0.31 and 0.33 are negative.
                [1., 2., 0.],  # Threshold 0.75: only 0.69 in first bin.
                [2., 0.,
                 2.],  # Threshold 1.0000001: 0.69 and 0.87 in first bin.
            ]))
        self.assertAllClose(
            oracle_auc.binned_false_positives,
            # Compare these values with oracle_auc.binned_true_negatives.
            # For example, the total across their rows must always be 4.
            np.array([
                [2., 0.,
                 2.],  # 0.69 and 0.87 in bin 3 (greater than -1e-7 + 0.66).
                [2., 2.,
                 0.],  # Threshold 0.25: 0.69 and 0.87 move to second bin.
                [1., 1.,
                 0.],  # Threshold 0.5: 0.69 (0.87) in first (second) bin.
                [1., 0.,
                 0.],  # Threshold 0.75: only 0.87 remains in first bin.
                [0., 0., 0.],  # Threshold 1.0000001: no more positives.
            ]))
        self.assertAllClose(
            oracle_auc.binned_false_negatives,
            # Compare these values with oracle_auc.binned_true_positives.
            np.array([
                [0., 0., 0.],  # No negatives
                [0., 0., 0.],  # No negatives
                [1., 0., 0.],  # Threshold 0.5: only 0.42 is below threshold.
                [2., 0.,
                 0.],  # Threshold 0.75: 0.42 still in bin 1; 0.58 joins it.
                [2., 2.,
                 0.],  # Threshold 1.0000001: 0.42 and 0.58 in second bin.
            ]))

        # The first and last thresholds are outside [0, 1] and are never corrected.
        # Second threshold: 2.5 corrected from fp to tn (all of the first bin,
        # plus a quarter of the second).
        # Third threshold: 5/6 corrected from fp to tn and 5/6 from fn to tp.
        # Fourth threshold: 5/6 corrected from fp->tn, 5/3 corrected from fn->tp.
        self.assertAllClose(oracle_auc.true_positives,
                            np.array([4., 4., 3. + 5 / 6, 2. + 5 / 3, 0.]))
        self.assertAllClose(
            oracle_auc.true_negatives,
            np.array([0., 2. + 0.5, 2. + 5 / 6, 3. + 5 / 6, 4.]))
        self.assertAllClose(
            oracle_auc.false_positives,
            np.array([4., 2. - 0.5, 2. - 5 / 6, 1. - 5 / 6, 0.]))
        self.assertAllClose(oracle_auc.false_negatives,
                            np.array([0., 0., 1. - 5 / 6, 2. - 5 / 3, 4.]))
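        # From the corrected counts above (our own arithmetic):
        #   recall    = TP / (TP + FN) = [1.0, 1.0, 0.958, 0.917, 0.0]
        #   precision = TP / (TP + FP) = [0.5, 0.727, 0.767, 0.957, n/a]
        # (no positives are predicted at the last threshold); the PR-curve
        # interpolation over these points yields the value asserted below.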

        self.assertEqual(result, 0.9434595)
def main(argv):
    del argv  # unused arg
    if not FLAGS.use_gpu:
        raise ValueError('Only GPU is currently supported.')
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')

    tf.random.set_seed(FLAGS.seed)
    logging.info('Model checkpoint will be saved at %s', FLAGS.output_dir)
    tf.io.gfile.makedirs(FLAGS.output_dir)

    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    test_batch_size = batch_size
    data_buffer_size = batch_size * 10

    ind_dataset_builder = ds.WikipediaToxicityDataset(
        split='test',
        data_dir=FLAGS.in_dataset_dir,
        shuffle_buffer_size=data_buffer_size)
    ood_dataset_builder = ds.CivilCommentsDataset(
        split='test',
        data_dir=FLAGS.ood_dataset_dir,
        shuffle_buffer_size=data_buffer_size)
    ood_identity_dataset_builder = ds.CivilCommentsIdentitiesDataset(
        split='test',
        data_dir=FLAGS.identity_dataset_dir,
        shuffle_buffer_size=data_buffer_size)

    test_dataset_builders = {
        'ind': ind_dataset_builder,
        'ood': ood_dataset_builder,
        'ood_identity': ood_identity_dataset_builder,
    }
    if FLAGS.prediction_mode and FLAGS.identity_prediction:
        for dataset_name in utils.IDENTITY_LABELS:
            if utils.NUM_EXAMPLES[dataset_name]['test'] > 100:
                test_dataset_builders[
                    dataset_name] = ds.CivilCommentsIdentitiesDataset(
                        split='test',
                        data_dir=os.path.join(
                            FLAGS.identity_specific_dataset_dir, dataset_name),
                        shuffle_buffer_size=data_buffer_size)
        for dataset_name in utils.IDENTITY_TYPES:
            if utils.NUM_EXAMPLES[dataset_name]['test'] > 100:
                test_dataset_builders[
                    dataset_name] = ds.CivilCommentsIdentitiesDataset(
                        split='test',
                        data_dir=os.path.join(FLAGS.identity_type_dataset_dir,
                                              dataset_name),
                        shuffle_buffer_size=data_buffer_size)

    class_weight = utils.create_class_weight(
        test_dataset_builders=test_dataset_builders)
    logging.info('class_weight: %s', str(class_weight))

    ds_info = ind_dataset_builder.tfds_info
    feature_size = _MAX_SEQ_LENGTH
    # Positive and negative classes.
    num_classes = ds_info.metadata['num_classes']

    test_datasets = {}
    steps_per_eval = {}
    for dataset_name, dataset_builder in test_dataset_builders.items():
        test_datasets[dataset_name] = dataset_builder.load(
            batch_size=test_batch_size)
        if dataset_name in ['ind', 'ood', 'ood_identity']:
            steps_per_eval[dataset_name] = (dataset_builder.num_examples //
                                            test_batch_size)
        else:
            steps_per_eval[dataset_name] = (
                utils.NUM_EXAMPLES[dataset_name]['test'] // test_batch_size)

    logging.info('Building %s model', FLAGS.model_family)

    bert_config_dir, _ = utils.resolve_bert_ckpt_and_config_dir(
        FLAGS.bert_model_type, FLAGS.bert_dir, FLAGS.bert_config_dir,
        FLAGS.bert_ckpt_dir)
    bert_config = utils.create_config(bert_config_dir)
    model, _ = ub.models.bert_model(num_classes=num_classes,
                                    max_seq_length=feature_size,
                                    bert_config=bert_config)

    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    if FLAGS.num_models > len(ensemble_filenames):
        raise ValueError('Number of models to be included in the ensemble '
                         'must not exceed the total number of checkpoints '
                         'found in checkpoint_dir.')
    ensemble_filenames = ensemble_filenames[:FLAGS.num_models]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    checkpoint = tf.train.Checkpoint(model=model)

    # Write model predictions to files.
    num_datasets = len(test_datasets)
    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename).assert_existing_objects_matched()
        for n, (dataset_name,
                test_dataset) in enumerate(test_datasets.items()):
            filename = '{dataset}_{member}.npy'.format(dataset=dataset_name,
                                                       member=m)
            filename = os.path.join(FLAGS.output_dir, filename)
            if not tf.io.gfile.exists(filename):
                logits = []
                test_iterator = iter(test_dataset)
                for step in range(steps_per_eval[dataset_name]):
                    try:
                        inputs = next(test_iterator)
                    except StopIteration:
                        continue
                    features, labels, _ = utils.create_feature_and_label(
                        inputs)
                    logits.append(model(features, training=False))

                logits = tf.concat(logits, axis=0)
                with tf.io.gfile.GFile(filename, 'w') as f:
                    np.save(f, logits.numpy())
            percent = (m * num_datasets +
                       (n + 1)) / (ensemble_size * num_datasets)
            message = (
                '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                           n + 1, num_datasets))
            logging.info(message)

    metrics = {
        'test/negative_log_likelihood':
        tf.keras.metrics.Mean(),
        'test/auroc':
        tf.keras.metrics.AUC(curve='ROC'),
        'test/aupr':
        tf.keras.metrics.AUC(curve='PR'),
        'test/brier':
        tf.keras.metrics.MeanSquaredError(),
        'test/brier_weighted':
        tf.keras.metrics.MeanSquaredError(),
        'test/ece':
        rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_ece_bins),
        'test/acc':
        tf.keras.metrics.Accuracy(),
        'test/acc_weighted':
        tf.keras.metrics.Accuracy(),
        'test/precision':
        tf.keras.metrics.Precision(),
        'test/recall':
        tf.keras.metrics.Recall(),
        'test/f1':
        tfa_metrics.F1Score(num_classes=num_classes,
                            average='micro',
                            threshold=FLAGS.ece_label_threshold)
    }

    for policy in ('uncertainty', 'toxicity'):
        metrics.update({
            'test_{}/calibration_auroc'.format(policy):
            tc_metrics.CalibrationAUC(curve='ROC'),
            'test_{}/calibration_auprc'.format(policy):
            tc_metrics.CalibrationAUC(curve='PR')
        })

        for fraction in FLAGS.fractions:
            metrics.update({
                'test_{}/collab_acc_{}'.format(policy, fraction):
                rm.metrics.OracleCollaborativeAccuracy(
                    fraction=float(fraction), num_bins=FLAGS.num_approx_bins),
                'test_{}/abstain_prec_{}'.format(policy, fraction):
                tc_metrics.AbstainPrecision(
                    abstain_fraction=float(fraction),
                    num_approx_bins=FLAGS.num_approx_bins),
                'test_{}/abstain_recall_{}'.format(policy, fraction):
                tc_metrics.AbstainRecall(
                    abstain_fraction=float(fraction),
                    num_approx_bins=FLAGS.num_approx_bins),
                'test_{}/collab_auroc_{}'.format(policy, fraction):
                tc_metrics.OracleCollaborativeAUC(
                    oracle_fraction=float(fraction),
                    num_bins=FLAGS.num_approx_bins),
                'test_{}/collab_auprc_{}'.format(policy, fraction):
                tc_metrics.OracleCollaborativeAUC(
                    oracle_fraction=float(fraction),
                    curve='PR',
                    num_bins=FLAGS.num_approx_bins),
            })

    for dataset_name, test_dataset in test_datasets.items():
        if dataset_name != 'ind':
            metrics.update({
                'test/nll_{}'.format(dataset_name):
                tf.keras.metrics.Mean(),
                'test/auroc_{}'.format(dataset_name):
                tf.keras.metrics.AUC(curve='ROC'),
                'test/aupr_{}'.format(dataset_name):
                tf.keras.metrics.AUC(curve='PR'),
                'test/brier_{}'.format(dataset_name):
                tf.keras.metrics.MeanSquaredError(),
                'test/brier_weighted_{}'.format(dataset_name):
                tf.keras.metrics.MeanSquaredError(),
                'test/ece_{}'.format(dataset_name):
                rm.metrics.ExpectedCalibrationError(
                    num_bins=FLAGS.num_ece_bins),
                'test/acc_weighted_{}'.format(dataset_name):
                tf.keras.metrics.Accuracy(),
                'test/acc_{}'.format(dataset_name):
                tf.keras.metrics.Accuracy(),
                'test/precision_{}'.format(dataset_name):
                tf.keras.metrics.Precision(),
                'test/recall_{}'.format(dataset_name):
                tf.keras.metrics.Recall(),
                'test/f1_{}'.format(dataset_name):
                tfa_metrics.F1Score(num_classes=num_classes,
                                    average='micro',
                                    threshold=FLAGS.ece_label_threshold)
            })

            for policy in ('uncertainty', 'toxicity'):
                metrics.update({
                    'test_{}/calibration_auroc_{}'.format(
                        policy, dataset_name):
                    tc_metrics.CalibrationAUC(curve='ROC'),
                    'test_{}/calibration_auprc_{}'.format(
                        policy, dataset_name):
                    tc_metrics.CalibrationAUC(curve='PR'),
                })

                for fraction in FLAGS.fractions:
                    metrics.update({
                        'test_{}/collab_acc_{}_{}'.format(
                            policy, fraction, dataset_name):
                        rm.metrics.OracleCollaborativeAccuracy(
                            fraction=float(fraction),
                            num_bins=FLAGS.num_approx_bins),
                        'test_{}/abstain_prec_{}_{}'.format(
                            policy, fraction, dataset_name):
                        tc_metrics.AbstainPrecision(
                            abstain_fraction=float(fraction),
                            num_approx_bins=FLAGS.num_approx_bins),
                        'test_{}/abstain_recall_{}_{}'.format(
                            policy, fraction, dataset_name):
                        tc_metrics.AbstainRecall(
                            abstain_fraction=float(fraction),
                            num_approx_bins=FLAGS.num_approx_bins),
                        'test_{}/collab_auroc_{}_{}'.format(
                            policy, fraction, dataset_name):
                        tc_metrics.OracleCollaborativeAUC(
                            oracle_fraction=float(fraction),
                            num_bins=FLAGS.num_approx_bins),
                        'test_{}/collab_auprc_{}_{}'.format(
                            policy, fraction, dataset_name):
                        tc_metrics.OracleCollaborativeAUC(
                            oracle_fraction=float(fraction),
                            curve='PR',
                            num_bins=FLAGS.num_approx_bins),
                    })

    @tf.function
    def generate_sample_weight(labels, class_weight, label_threshold=0.7):
        """Generate sample weight for weighted accuracy calculation."""
        if label_threshold != 0.7:
            logging.warning(
                'The class weight was based on `label_threshold` = 0.7, '
                'and weighted accuracy/brier will be meaningless if '
                '`label_threshold` is not equal to this value, which is '
                'recommended by Jigsaw Conversation AI team.')
        labels_int = tf.cast(labels > label_threshold, tf.int32)
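        # Assuming `class_weight` is ordered [non-toxic weight, toxic weight],
        # a soft label above the threshold (e.g. 0.9) maps to index 1 and
        # receives the toxic-class weight; a label of 0.2 maps to index 0.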
        sample_weight = tf.gather(class_weight, labels_int)
        return sample_weight

    # Evaluate model predictions.
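    # The per-member logits are assumed to have been precomputed and written
    # to '{dataset}_{member}.npy' files under FLAGS.output_dir; here they are
    # only loaded and aggregated into ensemble predictions.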
    for n, (dataset_name, test_dataset) in enumerate(test_datasets.items()):
        logits_dataset = []
        for m in range(ensemble_size):
            filename = '{dataset}_{member}.npy'.format(dataset=dataset_name,
                                                       member=m)
            filename = os.path.join(FLAGS.output_dir, filename)
            with tf.io.gfile.GFile(filename, 'rb') as f:
                logits_dataset.append(np.load(f))

        logits_dataset = tf.convert_to_tensor(logits_dataset)
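        # Expected shape: [ensemble_size, num_examples, 1], with axis 0
        # indexing the ensemble members.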
        test_iterator = iter(test_dataset)
        texts_list = []
        logits_list = []
        labels_list = []
        # Use a dict to collect the additional labels specified by the
        # additional label names. `OrderedDict` gives a consistent ordering,
        # so the predictions for each identity label can be retrieved in
        # Colab.
        additional_labels_dict = collections.OrderedDict()
        for step in range(steps_per_eval[dataset_name]):
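            # If the test iterator is exhausted before `steps_per_eval`, skip
            # the remaining steps.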
            try:
                inputs: Dict[str, tf.Tensor] = next(test_iterator)  # pytype: disable=annotation-type-mismatch
            except StopIteration:
                continue
            features, labels, additional_labels = (
                utils.create_feature_and_label(inputs))
            logits = logits_dataset[:, (step * batch_size):((step + 1) *
                                                            batch_size)]
            loss_logits = tf.squeeze(logits, axis=-1)
            negative_log_likelihood_metric = rm.metrics.EnsembleCrossEntropy(
                binary=True)
            negative_log_likelihood_metric.add_batch(loss_logits,
                                                     labels=labels)
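            # `result()` returns a single-entry dict; its value is the
            # ensemble negative log-likelihood for this batch.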
            negative_log_likelihood = list(
                negative_log_likelihood_metric.result().values())[0]

            per_probs = tf.nn.sigmoid(logits)
            probs = tf.reduce_mean(per_probs, axis=0)
            # Binarize the soft toxicity labels for ECE and the other
            # classification metrics.
            ece_labels = tf.cast(labels > FLAGS.ece_label_threshold,
                                 tf.float32)
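            # Two-class probabilities [P(non-toxic), P(toxic)] feed the
            # calibration and F1 metrics; auc_probs is the flat toxicity
            # probability used for the AUC and Brier metrics.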
            one_hot_labels = tf.one_hot(tf.cast(ece_labels, tf.int32),
                                        depth=num_classes)
            ece_probs = tf.concat([1. - probs, probs], axis=1)
            pred_labels = tf.math.argmax(ece_probs, axis=-1)
            auc_probs = tf.squeeze(probs, axis=1)

            # Use the normalized binary predictive variance as the confidence
            # score. The predictive variance p * (1 - p) lies in [0, 0.25],
            # so dividing by its maximum value keeps the confidence in [0, 1].
            calib_confidence = 1. - probs * (1. - probs) / .25
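            # For example, probs = 0.5 gives the minimum confidence 0., and
            # probs = 0.9 gives 1. - 0.09 / 0.25 = 0.64.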

            texts_list.append(inputs['input_ids'])
            logits_list.append(logits)
            labels_list.append(labels)
            if 'identity' in dataset_name:
                for identity_label_name in utils.IDENTITY_LABELS:
                    if identity_label_name not in additional_labels_dict:
                        additional_labels_dict[identity_label_name] = []
                    additional_labels_dict[identity_label_name].append(
                        additional_labels[identity_label_name].numpy())

            sample_weight = generate_sample_weight(
                labels, class_weight['test/{}'.format(dataset_name)],
                FLAGS.ece_label_threshold)
            if dataset_name == 'ind':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/auroc'].update_state(labels, auc_probs)
                metrics['test/aupr'].update_state(labels, auc_probs)
                metrics['test/brier'].update_state(labels, auc_probs)
                metrics['test/brier_weighted'].update_state(
                    tf.expand_dims(labels, -1),
                    probs,
                    sample_weight=sample_weight)
                metrics['test/ece'].add_batch(ece_probs, label=ece_labels)
                metrics['test/acc'].update_state(ece_labels, pred_labels)
                metrics['test/acc_weighted'].update_state(
                    ece_labels, pred_labels, sample_weight=sample_weight)
                metrics['test/precision'].update_state(ece_labels, pred_labels)
                metrics['test/recall'].update_state(ece_labels, pred_labels)
                metrics['test/f1'].update_state(one_hot_labels, ece_probs)

                for policy in ('uncertainty', 'toxicity'):
                    # Confidence is calib_confidence for the 'uncertainty'
                    # policy, or the decreasing toxicity score (1. - probs)
                    # for the 'toxicity' policy.
                    confidence = 1. - probs if policy == 'toxicity' else calib_confidence
                    binning_confidence = tf.squeeze(confidence)
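                    # The same confidence score drives the calibration and
                    # abstention metrics and, via `custom_binning_score`, the
                    # collaborative metrics below.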

                    metrics['test_{}/calibration_auroc'.format(
                        policy)].update_state(ece_labels, pred_labels,
                                              confidence)
                    metrics['test_{}/calibration_auprc'.format(
                        policy)].update_state(ece_labels, pred_labels,
                                              confidence)

                    for fraction in FLAGS.fractions:
                        metrics['test_{}/collab_acc_{}'.format(
                            policy, fraction)].add_batch(
                                ece_probs,
                                label=ece_labels,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/abstain_prec_{}'.format(
                            policy,
                            fraction)].update_state(ece_labels, pred_labels,
                                                    confidence)
                        metrics['test_{}/abstain_recall_{}'.format(
                            policy,
                            fraction)].update_state(ece_labels, pred_labels,
                                                    confidence)
                        metrics['test_{}/collab_auroc_{}'.format(
                            policy, fraction)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/collab_auprc_{}'.format(
                            policy, fraction)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)

            else:
                metrics['test/nll_{}'.format(dataset_name)].update_state(
                    negative_log_likelihood)
                metrics['test/auroc_{}'.format(dataset_name)].update_state(
                    labels, auc_probs)
                metrics['test/aupr_{}'.format(dataset_name)].update_state(
                    labels, auc_probs)
                metrics['test/brier_{}'.format(dataset_name)].update_state(
                    labels, auc_probs)
                metrics['test/brier_weighted_{}'.format(
                    dataset_name)].update_state(tf.expand_dims(labels, -1),
                                                probs,
                                                sample_weight=sample_weight)
                metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    ece_probs, label=ece_labels)
                metrics['test/acc_{}'.format(dataset_name)].update_state(
                    ece_labels, pred_labels)
                metrics['test/acc_weighted_{}'.format(
                    dataset_name)].update_state(ece_labels,
                                                pred_labels,
                                                sample_weight=sample_weight)
                metrics['test/precision_{}'.format(dataset_name)].update_state(
                    ece_labels, pred_labels)
                metrics['test/recall_{}'.format(dataset_name)].update_state(
                    ece_labels, pred_labels)
                metrics['test/f1_{}'.format(dataset_name)].update_state(
                    one_hot_labels, ece_probs)

                for policy in ('uncertainty', 'toxicity'):
                    # Confidence is calib_confidence for the 'uncertainty'
                    # policy, or the decreasing toxicity score (1. - probs)
                    # for the 'toxicity' policy.
                    confidence = 1. - probs if policy == 'toxicity' else calib_confidence
                    binning_confidence = tf.squeeze(confidence)

                    metrics['test_{}/calibration_auroc_{}'.format(
                        policy,
                        dataset_name)].update_state(ece_labels, pred_labels,
                                                    confidence)
                    metrics['test_{}/calibration_auprc_{}'.format(
                        policy,
                        dataset_name)].update_state(ece_labels, pred_labels,
                                                    confidence)

                    for fraction in FLAGS.fractions:
                        metrics['test_{}/collab_acc_{}_{}'.format(
                            policy, fraction, dataset_name)].add_batch(
                                ece_probs,
                                label=ece_labels,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/abstain_prec_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                ece_labels, pred_labels, confidence)
                        metrics['test_{}/abstain_recall_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                ece_labels, pred_labels, confidence)
                        metrics['test_{}/collab_auroc_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)
                        metrics['test_{}/collab_auprc_{}_{}'.format(
                            policy, fraction, dataset_name)].update_state(
                                labels,
                                auc_probs,
                                custom_binning_score=binning_confidence)

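        # Concatenate per-batch results. Logits are concatenated along axis 1
        # because axis 0 indexes the ensemble members.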
        texts_all = tf.concat(texts_list, axis=0)
        logits_all = tf.concat(logits_list, axis=1)
        labels_all = tf.concat(labels_list, axis=0)
        additional_labels_all = []
        if additional_labels_dict:
            additional_labels_all = list(additional_labels_dict.values())

        utils.save_prediction(texts_all.numpy(),
                              path=os.path.join(
                                  FLAGS.output_dir,
                                  'texts_{}'.format(dataset_name)))
        utils.save_prediction(labels_all.numpy(),
                              path=os.path.join(
                                  FLAGS.output_dir,
                                  'labels_{}'.format(dataset_name)))
        utils.save_prediction(logits_all.numpy(),
                              path=os.path.join(
                                  FLAGS.output_dir,
                                  'logits_{}'.format(dataset_name)))
        if 'identity' in dataset_name:
            utils.save_prediction(
                np.array(additional_labels_all),
                path=os.path.join(FLAGS.output_dir,
                                  'additional_labels_{}'.format(dataset_name)))

        message = (
            '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
                (n + 1) / num_datasets, n + 1, num_datasets))
        logging.info(message)

    total_results = {name: metric.result() for name, metric in metrics.items()}
    # Metrics from Robustness Metrics (like ECE) return a dict with a single
    # key/value pair instead of a scalar; unwrap it here.
    total_results = {
        k: (list(v.values())[0] if isinstance(v, dict) else v)
        for k, v in total_results.items()
    }
    logging.info('Metrics: %s', total_results)