Example 1
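A unit test that fits a small deterministic.wide_resnet (depth 10, width multiplier 1, version 2) on a tiny synthetic dataset, whose labels are drawn from a random linear model over the flattened features, and then checks that every recorded training loss is non-negative.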
    def testResNetV1(self):
        tf.random.set_seed(83922)
        dataset_size = 10
        batch_size = 5
        input_shape = (32, 32, 1)
        num_classes = 2

        features = tf.random.normal((dataset_size, ) + input_shape)
        coeffs = tf.random.normal([tf.reduce_prod(input_shape), num_classes])
        net = tf.reshape(features, [dataset_size, -1])
        logits = tf.matmul(net, coeffs)
        labels = tf.random.categorical(logits, 1)
        features, labels = self.evaluate([features, labels])
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        dataset = dataset.repeat().shuffle(dataset_size).batch(batch_size)

        model = deterministic.wide_resnet(input_shape=input_shape,
                                          depth=10,
                                          width_multiplier=1,
                                          num_classes=num_classes,
                                          l2=0.,
                                          version=2)
        model.compile('adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(
                          from_logits=True))
        history = model.fit(dataset,
                            steps_per_epoch=dataset_size // batch_size,
                            epochs=2)

        loss_history = history.history['loss']
        self.assertAllGreaterEqual(loss_history, 0.)
Example 2
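Evaluates a Wide ResNet 28-10 ensemble on the clean and corrupted CIFAR test sets. Each member's checkpoint is restored in turn and its test logits are written to .npy files in FLAGS.output_dir; the cached logits are then reloaded to compute ensemble negative log-likelihood, Gibbs cross entropy, accuracy, and expected calibration error.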
def main(argv):
    del argv  # unused arg
    if not FLAGS.use_gpu:
        raise ValueError('Only GPU is currently supported.')
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)
    tf.io.gfile.makedirs(FLAGS.output_dir)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    dataset_input_fn = utils.load_input_fn(
        split=tfds.Split.TEST,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size,
        use_bfloat16=FLAGS.use_bfloat16)
    test_datasets = {'clean': dataset_input_fn()}
    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    for name in corruption_types:
        for intensity in range(1, max_intensity + 1):
            dataset_name = '{0}_{1}'.format(name, intensity)
            if FLAGS.dataset == 'cifar10':
                load_c_dataset = utils.load_cifar10_c_input_fn
            else:
                load_c_dataset = functools.partial(
                    utils.load_cifar100_c_input_fn, path=FLAGS.cifar100_c_path)
            corrupted_input_fn = load_c_dataset(
                corruption_name=name,
                corruption_intensity=intensity,
                batch_size=FLAGS.per_core_batch_size,
                use_bfloat16=FLAGS.use_bfloat16)
            test_datasets[dataset_name] = corrupted_input_fn()

    model = deterministic.wide_resnet(
        input_shape=ds_info.features['image'].shape,
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        l2=0.,
        version=2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.checkpoint_dir, '**/*.index'))
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    checkpoint = tf.train.Checkpoint(model=model)

    # Write model predictions to files.
    num_datasets = len(test_datasets)
    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        for n, (name, test_dataset) in enumerate(test_datasets.items()):
            filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
            filename = os.path.join(FLAGS.output_dir, filename)
            if not tf.io.gfile.exists(filename):
                logits = []
                test_iterator = iter(test_dataset)
                for _ in range(steps_per_eval):
                    features, _ = next(test_iterator)  # pytype: disable=attribute-error
                    logits.append(model(features, training=False))

                logits = tf.concat(logits, axis=0)
                with tf.io.gfile.GFile(filename, 'w') as f:
                    np.save(f, logits.numpy())
            percent = (m * num_datasets +
                       (n + 1)) / (ensemble_size * num_datasets)
            message = (
                '{:.1%} completion for prediction: ensemble member {:d}/{:d}. '
                'Dataset {:d}/{:d}'.format(percent, m + 1, ensemble_size,
                                           n + 1, num_datasets))
            logging.info(message)

    metrics = {
        'test/negative_log_likelihood': tf.keras.metrics.Mean(),
        'test/gibbs_cross_entropy': tf.keras.metrics.Mean(),
        'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
        'test/ece':
        ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
    }
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(name)] = (
            tf.keras.metrics.SparseCategoricalAccuracy())
        corrupt_metrics['test/ece_{}'.format(name)] = (
            ed.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

    # Evaluate model predictions.
    for n, (name, test_dataset) in enumerate(test_datasets.items()):
        logits_dataset = []
        for m in range(ensemble_size):
            filename = '{dataset}_{member}.npy'.format(dataset=name, member=m)
            filename = os.path.join(FLAGS.output_dir, filename)
            with tf.io.gfile.GFile(filename, 'rb') as f:
                logits_dataset.append(np.load(f))

        logits_dataset = tf.convert_to_tensor(logits_dataset)
        test_iterator = iter(test_dataset)
        for step in range(steps_per_eval):
            _, labels = next(test_iterator)  # pytype: disable=attribute-error
            logits = logits_dataset[:, (step * batch_size):((step + 1) *
                                                            batch_size)]
            labels = tf.cast(labels, tf.int32)
            negative_log_likelihood = tf.reduce_mean(
                ensemble_negative_log_likelihood(labels, logits))
            per_probs = tf.nn.softmax(logits)
            probs = tf.reduce_mean(per_probs, axis=0)
            if name == 'clean':
                gibbs_ce = tf.reduce_mean(gibbs_cross_entropy(labels, logits))
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/gibbs_cross_entropy'].update_state(gibbs_ce)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            else:
                corrupt_metrics['test/nll_{}'.format(name)].update_state(
                    negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                    labels, probs)
                corrupt_metrics['test/ece_{}'.format(name)].update_state(
                    labels, probs)

        message = (
            '{:.1%} completion for evaluation: dataset {:d}/{:d}'.format(
                (n + 1) / num_datasets, n + 1, num_datasets))
        logging.info(message)

    corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                      corruption_types,
                                                      max_intensity)
    total_results = {name: metric.result() for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    logging.info('Metrics: %s', total_results)
Example 3
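A variant of the same evaluation that keeps every member's logits in memory rather than writing them to disk, then computes the ensemble metrics (negative log-likelihood, Gibbs cross entropy, accuracy, ECE) on the clean and corrupted test sets.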
def main(argv):
    del argv  # unused arg
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)

    dataset_input_fn = utils.load_input_fn(tfds.Split.TEST,
                                           FLAGS.per_core_batch_size,
                                           name=FLAGS.dataset,
                                           use_bfloat16=False,
                                           normalize=True,
                                           drop_remainder=True,
                                           proportion=1.0)
    test_datasets = {'clean': dataset_input_fn()}

    ds_info = tfds.builder(FLAGS.dataset).info
    num_classes = ds_info.features['label'].num_classes
    model = deterministic.wide_resnet(
        input_shape=ds_info.features['image'].shape,
        depth=28,
        width_multiplier=10,
        num_classes=num_classes,
        l2=0.,
        version=2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.output_dir, '**/*.index'))
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    checkpoint = tf.train.Checkpoint(model=model)

    # Collect the logits output for each ensemble member and train/test data
    # point. We also collect the labels.
    # TODO(trandustin): Refactor data loader so you can get the full dataset in
    # memory without looping.
    logits_test = {'clean': []}
    labels_test = {'clean': []}
    corruption_types, max_intensity = utils.load_corrupted_test_info(
        FLAGS.dataset)
    for name in corruption_types:
        for intensity in range(1, max_intensity + 1):
            dataset_name = '{0}_{1}'.format(name, intensity)
            logits_test[dataset_name] = []
            labels_test[dataset_name] = []

            if FLAGS.dataset == 'cifar10':
                load_c_dataset = utils.load_cifar10_c_input_fn
            else:
                load_c_dataset = functools.partial(
                    utils.load_cifar100_c_input_fn, path=FLAGS.cifar100_c_path)
            corrupted_input_fn = load_c_dataset(
                corruption_name=name,
                corruption_intensity=intensity,
                batch_size=FLAGS.per_core_batch_size,
                use_bfloat16=False)
            test_datasets[dataset_name] = corrupted_input_fn()

    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        logging.info('Working on test data for ensemble member %s', m)
        for name, test_dataset in test_datasets.items():
            logits = []
            for features, labels in test_dataset:
                logits.append(model(features, training=False))
                if m == 0:
                    labels_test[name].append(labels)

            logits = tf.concat(logits, axis=0)
            logits_test[name].append(logits)
            if m == 0:
                labels_test[name] = tf.concat(labels_test[name], axis=0)
            logging.info('Finished testing on %s', name)

    metrics = {
        'test/ece':
        ed.metrics.ExpectedCalibrationError(num_classes=num_classes,
                                            num_bins=15)
    }
    corrupt_metrics = {}
    for name in test_datasets:
        corrupt_metrics['test/ece_{}'.format(
            name)] = ed.metrics.ExpectedCalibrationError(
                num_classes=num_classes, num_bins=15)
        corrupt_metrics['test/nll_{}'.format(name)] = tf.keras.metrics.Mean()
        corrupt_metrics['test/accuracy_{}'.format(
            name)] = tf.keras.metrics.Mean()

    for name, test_dataset in test_datasets.items():
        labels = labels_test[name]
        logits = logits_test[name]
        nll_test = ensemble_negative_log_likelihood(labels, logits)
        gibbs_ce_test = gibbs_cross_entropy(labels_test[name],
                                            logits_test[name])
        labels = tf.cast(labels, tf.int32)
        logits = tf.convert_to_tensor(logits)
        per_probs = tf.nn.softmax(logits)
        probs = tf.reduce_mean(per_probs, axis=0)
        accuracy = tf.keras.metrics.sparse_categorical_accuracy(labels, probs)
        if name == 'clean':
            metrics['test/negative_log_likelihood'] = tf.reduce_mean(nll_test)
            metrics['test/gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)
            metrics['test/accuracy'] = tf.reduce_mean(accuracy)
            metrics['test/ece'].update_state(labels, probs)
        else:
            corrupt_metrics['test/nll_{}'.format(name)].update_state(
                tf.reduce_mean(nll_test))
            corrupt_metrics['test/accuracy_{}'.format(name)].update_state(
                tf.reduce_mean(accuracy))
            corrupt_metrics['test/ece_{}'.format(name)].update_state(
                labels, probs)

    corrupt_results = utils.aggregate_corrupt_metrics(corrupt_metrics,
                                                      corruption_types,
                                                      max_intensity)
    metrics['test/ece'] = metrics['test/ece'].result()
    total_results = {name: metric for name, metric in metrics.items()}
    total_results.update(corrupt_results)
    logging.info('Metrics: %s', total_results)
Example 4
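Evaluates the ensemble on the clean train and test splits only: per-member logits are collected in memory, and ensemble negative log-likelihood, Gibbs cross entropy, and accuracy are reported for both splits.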
def main(argv):
    del argv  # unused arg
    if FLAGS.num_cores > 1:
        raise ValueError('Only a single accelerator is currently supported.')
    tf.enable_v2_behavior()
    tf.random.set_seed(FLAGS.seed)

    # TODO(trandustin): Replace with load_distributed_dataset. Currently hangs.
    dataset_train = utils.load_dataset(tfds.Split.TRAIN, FLAGS.dataset)
    dataset_test = utils.load_dataset(tfds.Split.TEST, FLAGS.dataset)
    dataset_train = dataset_train.batch(FLAGS.per_core_batch_size)
    dataset_test = dataset_test.batch(FLAGS.per_core_batch_size)
    ds_info = tfds.builder(FLAGS.dataset).info

    model = deterministic.wide_resnet(
        input_shape=ds_info.features['image'].shape,
        depth=28,
        width_multiplier=10,
        num_classes=ds_info.features['label'].num_classes,
        l2=0.,
        version=2)
    logging.info('Model input shape: %s', model.input_shape)
    logging.info('Model output shape: %s', model.output_shape)
    logging.info('Model number of weights: %s', model.count_params())

    # Search for checkpoints from their index file; then remove the index suffix.
    ensemble_filenames = tf.io.gfile.glob(
        os.path.join(FLAGS.output_dir, '**/*.index'))
    ensemble_filenames = [filename[:-6] for filename in ensemble_filenames]
    ensemble_size = len(ensemble_filenames)
    logging.info('Ensemble size: %s', ensemble_size)
    logging.info('Ensemble number of weights: %s',
                 ensemble_size * model.count_params())
    logging.info('Ensemble filenames: %s', str(ensemble_filenames))
    checkpoint = tf.train.Checkpoint(model=model)

    # Collect the logits output for each ensemble member and train/test data
    # point. We also collect the labels.
    # TODO(trandustin): Refactor data loader so you can get the full dataset in
    # memory without looping.
    logits_train = []
    logits_test = []
    labels_train = []
    labels_test = []
    start_time = time.time()
    for m, ensemble_filename in enumerate(ensemble_filenames):
        checkpoint.restore(ensemble_filename)
        logits = []
        logging.info('Working on training data for ensemble member %s', m)
        for features, labels in dataset_train:
            logits.append(model(features, training=False))
            if m == 0:
                labels_train.append(labels)

        logits = tf.concat(logits, axis=0)
        logits_train.append(logits)
        if m == 0:
            labels_train = tf.concat(labels_train, axis=0)

        logging.info('Working on test data for ensemble member %s', m)
        logits = []
        for features, labels in dataset_test:
            logits.append(model(features, training=False))
            if m == 0:
                labels_test.append(labels)

        logits = tf.concat(logits, axis=0)
        logits_test.append(logits)
        if m == 0:
            labels_test = tf.concat(labels_test, axis=0)

        batch_size = FLAGS.per_core_batch_size
        steps_per_epoch = ds_info.splits['train'].num_examples // batch_size
        steps_per_eval = ds_info.splits['test'].num_examples // batch_size
        current_step = (steps_per_epoch + steps_per_eval) * (m + 1)
        max_steps = (steps_per_epoch + steps_per_eval) * ensemble_size
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = (
            '{:.1%} completion: ensemble member {:d}/{:d}. {:.1f} steps/s. '
            'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                (m + 1) / ensemble_size, m + 1, ensemble_size, steps_per_sec,
                eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

    metrics = {}

    # Compute the ensemble's NLL and Gibbs cross entropy for each data point.
    # Then average over the dataset.
    nll_train = ensemble_negative_log_likelihood(labels_train, logits_train)
    nll_test = ensemble_negative_log_likelihood(labels_test, logits_test)
    gibbs_ce_train = gibbs_cross_entropy(labels_train, logits_train)
    gibbs_ce_test = gibbs_cross_entropy(labels_test, logits_test)
    metrics['train_negative_log_likelihood'] = tf.reduce_mean(nll_train)
    metrics['test_negative_log_likelihood'] = tf.reduce_mean(nll_test)
    metrics['train_gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_train)
    metrics['test_gibbs_cross_entropy'] = tf.reduce_mean(gibbs_ce_test)

    # Given the per-element logits tensor of shape [ensemble_size, dataset_size,
    # num_classes], average over the ensemble members' probabilities. Then
    # compute accuracy and average over the dataset.
    probs_train = tf.reduce_mean(tf.nn.softmax(logits_train), axis=0)
    probs_test = tf.reduce_mean(tf.nn.softmax(logits_test), axis=0)
    accuracy_train = tf.keras.metrics.sparse_categorical_accuracy(
        labels_train, probs_train)
    accuracy_test = tf.keras.metrics.sparse_categorical_accuracy(
        labels_test, probs_test)
    metrics['train_accuracy'] = tf.reduce_mean(accuracy_train)
    metrics['test_accuracy'] = tf.reduce_mean(accuracy_test)
    logging.info('Metrics: %s', metrics)
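
Examples 2-4 call ensemble_negative_log_likelihood and gibbs_cross_entropy, which are defined elsewhere in the baseline scripts and are not shown above. A minimal sketch of how such helpers are typically implemented, assuming logits has shape [ensemble_size, batch_size, num_classes] and labels has shape [batch_size]; the exact definitions live in the surrounding baseline code:

def ensemble_negative_log_likelihood(labels, logits):
    """Per-example NLL of the ensemble's mixture predictive distribution."""
    labels = tf.cast(labels, tf.int32)
    logits = tf.convert_to_tensor(logits)
    ensemble_size = float(logits.shape[0])
    # Broadcast labels across ensemble members and compute per-member NLLs.
    member_labels = tf.broadcast_to(labels[tf.newaxis, ...],
                                    tf.shape(logits)[:-1])
    nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=member_labels,
                                                         logits=logits)
    # -log((1/M) * sum_m p_m(y|x)) = -logsumexp_m(-nll_m) + log(M).
    return -tf.reduce_logsumexp(-nll, axis=0) + tf.math.log(ensemble_size)


def gibbs_cross_entropy(labels, logits):
    """Average of the individual members' cross entropies."""
    labels = tf.cast(labels, tf.int32)
    logits = tf.convert_to_tensor(logits)
    member_labels = tf.broadcast_to(labels[tf.newaxis, ...],
                                    tf.shape(logits)[:-1])
    nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=member_labels,
                                                         logits=logits)
    return tf.reduce_mean(nll, axis=0)

The log-sum-exp form evaluates the mixture likelihood in a numerically stable way. Example 1 does not use these helpers; its test trains the model directly with Keras' SparseCategoricalCrossentropy loss.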