def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)

    # Set seeds
    tf.random.set_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)

    # Resolve CUDA device(s)
    if FLAGS.use_gpu and torch.cuda.is_available():
        print('Running model with CUDA.')
        device = 'cuda:0'
    else:
        print('Running model on CPU.')
        device = 'cpu'

    train_batch_size = FLAGS.train_batch_size
    eval_batch_size = FLAGS.eval_batch_size // FLAGS.num_dropout_samples_eval
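    # Each eval example is forwarded num_dropout_samples_eval times, so the
    # per-step batch is shrunk to keep eval memory comparable to training.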

    # As per the Kaggle challenge, we have split sizes:
    # train: 35,126
    # validation: 10,906
    # test: 42,670
    ds_info = tfds.builder('diabetic_retinopathy_detection').info
    steps_per_epoch = ds_info.splits['train'].num_examples // train_batch_size
    steps_per_validation_eval = (ds_info.splits['validation'].num_examples //
                                 eval_batch_size)
    steps_per_test_eval = ds_info.splits['test'].num_examples // eval_batch_size

    data_dir = FLAGS.data_dir

    dataset_train_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                            split='train',
                                            data_dir=data_dir)
    dataset_train = dataset_train_builder.load(batch_size=train_batch_size)

    dataset_validation_builder = ub.datasets.get(
        'diabetic_retinopathy_detection',
        split='validation',
        data_dir=data_dir,
        is_training=not FLAGS.use_validation)
    validation_batch_size = (eval_batch_size
                             if FLAGS.use_validation else train_batch_size)
    dataset_validation = dataset_validation_builder.load(
        batch_size=validation_batch_size)
    if not FLAGS.use_validation:
        # Note that this will not create any mixed batches of train and validation
        # images.
        dataset_train = dataset_train.concatenate(dataset_validation)

    dataset_test_builder = ub.datasets.get('diabetic_retinopathy_detection',
                                           split='test',
                                           data_dir=data_dir)
    dataset_test = dataset_test_builder.load(batch_size=eval_batch_size)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    # MC Dropout ResNet50 based on PyTorch Vision implementation
    logging.info('Building Torch ResNet-50 MC Dropout model.')
    model = ub.models.resnet50_dropout_torch(num_classes=1,
                                             dropout_rate=FLAGS.dropout_rate)
    logging.info('Model number of weights: %s',
                 torch_utils.count_parameters(model))

    # SGD with Nesterov momentum; the learning rate follows the cosine
    # warm-restart schedule defined below.
    base_lr = FLAGS.base_learning_rate
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=base_lr,
                                momentum=1.0 - FLAGS.one_minus_momentum,
                                nesterov=True)
    steps_to_lr_peak = int(steps_per_epoch * FLAGS.lr_warmup_epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, steps_to_lr_peak, T_mult=2)
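    # Warm-restart cosine schedule: the LR decays over T_0 = steps_to_lr_peak
    # scheduler steps, restarts at the base rate, and each subsequent period
    # doubles (T_mult=2).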

    model = model.to(device)

    metrics = utils.get_diabetic_retinopathy_base_metrics(
        use_tpu=False,
        num_bins=FLAGS.num_bins,
        use_validation=FLAGS.use_validation)

    # Define additional metrics that would fail in a TF TPU implementation.
    metrics.update(
        utils.get_diabetic_retinopathy_cpu_metrics(
            use_validation=FLAGS.use_validation))

    # Binary cross-entropy on sigmoid probabilities. An unreduced variant is
    # needed at eval time to get per-sample log-likelihoods.
    loss_fn = torch.nn.BCELoss()
    eval_loss_fn = torch.nn.BCELoss(reduction='none')
    sigmoid = torch.nn.Sigmoid()
    max_steps = steps_per_epoch * FLAGS.train_epochs
    image_h = 512
    image_w = 512

    def run_train_epoch(iterator):
        def train_step(inputs):
            images = inputs['features']
            labels = inputs['labels']
            images = torch.from_numpy(images._numpy()).view(  # pylint: disable=protected-access
                train_batch_size, 3, image_h, image_w).to(device)
            labels = torch.from_numpy(labels._numpy()).to(device).float()  # pylint: disable=protected-access

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward
            logits = model(images)
            probs = sigmoid(logits).squeeze(-1)

            # Add L2 regularization loss to NLL
            negative_log_likelihood = loss_fn(probs, labels)
            l2_loss = sum(p.pow(2.0).sum() for p in model.parameters())
            loss = negative_log_likelihood + (FLAGS.l2 * l2_loss)

            # Backward/optimizer
            loss.backward()
            optimizer.step()

            # Convert to NumPy for metrics updates
            loss = loss.detach()
            negative_log_likelihood = negative_log_likelihood.detach()
            labels = labels.detach()
            probs = probs.detach()

            if device != 'cpu':
                loss = loss.cpu()
                negative_log_likelihood = negative_log_likelihood.cpu()
                labels = labels.cpu()
                probs = probs.cpu()

            loss = loss.numpy()
            negative_log_likelihood = negative_log_likelihood.numpy()
            labels = labels.numpy()
            probs = probs.numpy()

            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, probs)
            metrics['train/auprc'].update_state(labels, probs)
            metrics['train/auroc'].update_state(labels, probs)
            metrics['train/ece'].add_batch(probs, label=labels)

        for step in range(steps_per_epoch):
            train_step(next(iterator))

            if step % 100 == 0:
                current_step = epoch * steps_per_epoch + (step + 1)
                time_elapsed = time.time() - start_time
                steps_per_sec = float(current_step) / time_elapsed
                eta_seconds = ((max_steps - current_step) / steps_per_sec
                               if steps_per_sec else 0)
                message = (
                    '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                    'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                        current_step / max_steps, epoch + 1,
                        FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                        time_elapsed / 60))
                logging.info(message)

    def run_eval_epoch(iterator, dataset_split, num_steps):
        def eval_step(inputs, model):
            images = inputs['features']
            labels = inputs['labels']
            images = torch.from_numpy(images._numpy()).view(  # pylint: disable=protected-access
                eval_batch_size, 3, image_h, image_w).to(device)
            labels = torch.from_numpy(
                labels._numpy()).to(device).float().unsqueeze(-1)  # pylint: disable=protected-access

            with torch.no_grad():
                logits = torch.stack([
                    model(images)
                    for _ in range(FLAGS.num_dropout_samples_eval)
                ],
                                     dim=-1)

            # Logits shape is (batch_size, 1, num_dropout_samples); drop the
            # singleton dimension explicitly so a batch of size 1 survives.
            logits = logits.squeeze(1)

            # It is now (batch_size, num_dropout_samples).
            probs = sigmoid(logits)

            # labels_tiled shape is (batch_size, num_dropout_samples).
            labels_tiled = torch.tile(labels,
                                      (1, FLAGS.num_dropout_samples_eval))

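            # Marginal NLL over the S dropout samples, computed stably:
            #   -log((1/S) * sum_s p(y | x, w_s))
            #     = -logsumexp_s(log p(y | x, w_s)) + log(S).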
            log_likelihoods = -eval_loss_fn(probs, labels_tiled)
            negative_log_likelihood = torch.mean(
                -torch.logsumexp(log_likelihoods, dim=-1) +
                torch.log(torch.tensor(float(FLAGS.num_dropout_samples_eval))))

            probs = torch.mean(probs, dim=-1)

            # Convert to NumPy for metrics updates
            negative_log_likelihood = negative_log_likelihood.detach()
            labels = labels.detach()
            probs = probs.detach()

            if device != 'cpu':
                negative_log_likelihood = negative_log_likelihood.cpu()
                labels = labels.cpu()
                probs = probs.cpu()

            negative_log_likelihood = negative_log_likelihood.numpy()
            labels = labels.numpy()
            probs = probs.numpy()

            metrics[dataset_split + '/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics[dataset_split + '/accuracy'].update_state(labels, probs)
            metrics[dataset_split + '/auprc'].update_state(labels, probs)
            metrics[dataset_split + '/auroc'].update_state(labels, probs)
            metrics[dataset_split + '/ece'].add_batch(probs, label=labels)

        for _ in range(num_steps):
            eval_step(next(iterator), model=model)

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    start_time = time.time()
    initial_epoch = 0
    train_iterator = iter(dataset_train)
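    # The model stays in train mode for the entire run, so dropout remains
    # active inside run_eval_epoch; the repeated forward passes there are the
    # Monte Carlo dropout samples.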
    model.train()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch + 1)

        run_train_epoch(train_iterator)

        if FLAGS.use_validation:
            validation_iterator = iter(dataset_validation)
            logging.info('Starting to run validation eval at epoch: %s',
                         epoch + 1)
            run_eval_epoch(validation_iterator, 'validation',
                           steps_per_validation_eval)

        test_iterator = iter(dataset_test)
        logging.info('Starting to run test eval at epoch: %s', epoch + 1)
        test_start_time = time.time()
        run_eval_epoch(test_iterator, 'test', steps_per_test_eval)
        ms_per_example = (time.time() - test_start_time) * 1e3 / (
            steps_per_test_eval * eval_batch_size)
        metrics['test/ms_per_example'].update_state(ms_per_example)

        # Step scheduler
        scheduler.step()

        # Log and write to summary the epoch metrics
        utils.log_epoch_metrics(metrics=metrics, use_tpu=False)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):

            checkpoint_path = os.path.join(FLAGS.output_dir,
                                           f'model_{epoch + 1}.pt')
            torch_utils.checkpoint_torch_model(model=model,
                                               optimizer=optimizer,
                                               epoch=epoch + 1,
                                               checkpoint_path=checkpoint_path)
            logging.info('Saved Torch checkpoint to %s', checkpoint_path)

    final_checkpoint_path = os.path.join(FLAGS.output_dir,
                                         f'model_{FLAGS.train_epochs}.pt')
    torch_utils.checkpoint_torch_model(model=model,
                                       optimizer=optimizer,
                                       epoch=FLAGS.train_epochs,
                                       checkpoint_path=final_checkpoint_path)
    logging.info('Saved last checkpoint to %s', final_checkpoint_path)

    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'dropout_rate': FLAGS.dropout_rate,
            'l2': FLAGS.l2,
            'lr_warmup_epochs': FLAGS.lr_warmup_epochs
        })
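
# A minimal standalone sketch (not part of the script above) of the MC dropout
# prediction it performs: keep dropout active at inference and average the
# sigmoid outputs of S stochastic forward passes. Only torch is assumed.
import torch


def mc_dropout_predict(model, images, num_samples=5):
    model.train()  # keep dropout layers stochastic (MC dropout)
    with torch.no_grad():
        probs = torch.stack(
            [torch.sigmoid(model(images)).squeeze(-1)
             for _ in range(num_samples)],
            dim=-1)
    return probs.mean(dim=-1)  # predictive mean over dropout samples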
Example #2
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    train_input_fn = utils.load_input_fn(
        split=tfds.Split.TRAIN,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size // FLAGS.batch_repetitions,
        use_bfloat16=FLAGS.use_bfloat16)
    clean_test_input_fn = utils.load_input_fn(
        split=tfds.Split.TEST,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size,
        use_bfloat16=FLAGS.use_bfloat16)
    train_dataset = strategy.experimental_distribute_datasets_from_function(
        train_input_fn)
    test_datasets = {
        'clean':
        strategy.experimental_distribute_datasets_from_function(
            clean_test_input_fn),
    }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar10':
            load_c_input_fn = utils.load_cifar10_c_input_fn
        else:
            load_c_input_fn = functools.partial(utils.load_cifar100_c_input_fn,
                                                path=FLAGS.cifar100_c_path)
        corruption_types, max_intensity = utils.load_corrupted_test_info(
            FLAGS.dataset)
        for corruption in corruption_types:
            for intensity in range(1, max_intensity + 1):
                input_fn = load_c_input_fn(
                    corruption_name=corruption,
                    corruption_intensity=intensity,
                    batch_size=FLAGS.per_core_batch_size,
                    use_bfloat16=FLAGS.use_bfloat16)
                test_datasets['{0}_{1}'.format(corruption, intensity)] = (
                    strategy.experimental_distribute_datasets_from_function(
                        input_fn))

    ds_info = tfds.builder(FLAGS.dataset).info
    train_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores // FLAGS.batch_repetitions
    test_batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    train_dataset_size = ds_info.splits['train'].num_examples
    steps_per_epoch = train_dataset_size // train_batch_size
    steps_per_eval = ds_info.splits['test'].num_examples // test_batch_size
    num_classes = ds_info.features['label'].num_classes

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras model')
        model = cifar_model.wide_resnet(
            input_shape=[FLAGS.ensemble_size] +
            list(ds_info.features['image'].shape),
            depth=28,
            width_multiplier=FLAGS.width_multiplier,
            num_classes=num_classes,
            ensemble_size=FLAGS.ensemble_size)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * train_batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = utils.LearningRateSchedule(steps_per_epoch, base_lr,
                                                 FLAGS.lr_decay_ratio,
                                                 lr_decay_epochs,
                                                 FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood': tf.keras.metrics.Mean(),
            'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss': tf.keras.metrics.Mean(),
            'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood': tf.keras.metrics.Mean(),
            'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

        for i in range(FLAGS.ensemble_size):
            metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
            metrics['test/accuracy_member_{}'.format(i)] = (
                tf.keras.metrics.SparseCategoricalAccuracy())
        test_diversity = {
            'test/disagreement': tf.keras.metrics.Mean(),
            'test/average_kl': tf.keras.metrics.Mean(),
            'test/cosine_similarity': tf.keras.metrics.Mean(),
        }

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            batch_size = tf.shape(images)[0]

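            # MIMO input pipeline: build one shuffled index list in which each
            # example appears `batch_repetitions` times, then give every
            # ensemble member its own ordering. The leading
            # (1 - input_repetition_probability) fraction is reshuffled
            # independently per member; the tail is shared across members.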
            main_shuffle = tf.random.shuffle(
                tf.tile(tf.range(batch_size), [FLAGS.batch_repetitions]))
            to_shuffle = tf.cast(
                tf.cast(tf.shape(main_shuffle)[0], tf.float32) *
                (1. - FLAGS.input_repetition_probability), tf.int32)
            shuffle_indices = [
                tf.concat([
                    tf.random.shuffle(main_shuffle[:to_shuffle]),
                    main_shuffle[to_shuffle:]
                ],
                          axis=0) for _ in range(FLAGS.ensemble_size)
            ]
            images = tf.stack([
                tf.gather(images, indices, axis=0)
                for indices in shuffle_indices
            ],
                              axis=1)
            labels = tf.stack([
                tf.gather(labels, indices, axis=0)
                for indices in shuffle_indices
            ],
                              axis=1)

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)

                negative_log_likelihood = tf.reduce_mean(
                    tf.reduce_sum(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True),
                        axis=1))
                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply l2 to the kernel, batch norm, and bias parameters.
                    if ('kernel' in var.name or 'batch_norm' in var.name
                            or 'bias' in var.name):
                        filtered_variables.append(tf.reshape(var, (-1, )))

                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))

                # Scale the loss given the TPUStrategy will reduce sum all gradients.
                loss = negative_log_likelihood + l2_loss
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(tf.reshape(logits, [-1, num_classes]))
            flat_labels = tf.reshape(labels, [-1])
            metrics['train/ece'].update_state(flat_labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(flat_labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            images = tf.tile(tf.expand_dims(images, 1),
                             [1, FLAGS.ensemble_size, 1, 1, 1])
            logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.softmax(logits)

            if dataset_name == 'clean':
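                # Shape (ensemble_size, batch_size, num_classes), as expected
                # by the pairwise diversity metrics.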
                per_probs = tf.transpose(probs, perm=[1, 0, 2])
                diversity_results = um.average_pairwise_diversity(
                    per_probs, FLAGS.ensemble_size)
                for k, v in diversity_results.items():
                    test_diversity['test/' + k].update_state(v)

            for i in range(FLAGS.ensemble_size):
                member_probs = probs[:, i]
                member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                metrics['test/nll_member_{}'.format(i)].update_state(
                    member_loss)
                metrics['test/accuracy_member_{}'.format(i)].update_state(
                    labels, member_probs)

            # Negative log marginal likelihood computed in a numerically-stable way.
            labels_tiled = tf.tile(tf.expand_dims(labels, 1),
                                   [1, FLAGS.ensemble_size])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_tiled, logits, from_logits=True)
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[1]) +
                tf.math.log(float(FLAGS.ensemble_size)))
            probs = tf.math.reduce_mean(probs, axis=1)  # marginalize

            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)

        for step in range(steps_per_epoch):
            train_step(train_iterator)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)

        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_start_time = time.time()
                test_step(test_iterator, dataset_name)
                ms_per_example = (time.time() -
                                  test_start_time) * 1e3 / test_batch_size
                metrics['test/ms_per_example'].update_state(ms_per_example)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types, max_intensity)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        for i in range(FLAGS.ensemble_size):
            logging.info(
                'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                metrics['test/nll_member_{}'.format(i)].result(),
                metrics['test/accuracy_member_{}'.format(i)].result() * 100)

        metrics.update(test_diversity)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
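
# A minimal standalone sketch (not from the script above) of the MIMO index
# trick, with toy sizes; input_repetition_probability is assumed to be 0.5.
import tensorflow as tf

batch_size, batch_repetitions, ensemble_size = 4, 2, 3
main_shuffle = tf.random.shuffle(
    tf.tile(tf.range(batch_size), [batch_repetitions]))
to_shuffle = tf.cast(
    tf.cast(tf.shape(main_shuffle)[0], tf.float32) * (1. - 0.5), tf.int32)
shuffle_indices = [
    tf.concat([tf.random.shuffle(main_shuffle[:to_shuffle]),
               main_shuffle[to_shuffle:]], axis=0)
    for _ in range(ensemble_size)
]
print([idx.numpy().tolist() for idx in shuffle_indices])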
Example #3
def run(flags_obj, datasets_override=None, strategy_override=None):
    """Run MNIST model training and eval loop using native Keras APIs.

  Args:
    flags_obj: An object containing parsed flag values.
    datasets_override: A pair of `tf.data.Dataset` objects to train the model,
                       representing the train and test sets.
    strategy_override: A `tf.distribute.Strategy` object to use for the model.

  Returns:
    Dictionary of training and eval stats.
  """
    strategy = strategy_override or distribute_utils.get_distribution_strategy(
        distribution_strategy=flags_obj.distribution_strategy,
        num_gpus=flags_obj.num_gpus,
        tpu_address=flags_obj.tpu)

    strategy_scope = distribute_utils.get_strategy_scope(strategy)

    mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
    if flags_obj.download:
        mnist.download_and_prepare()

    mnist_train, mnist_test = datasets_override or mnist.as_dataset(
        split=['train', 'test'],
        decoders={'image': decode_image()},  # pylint: disable=no-value-for-parameter
        as_supervised=True)
    train_input_dataset = mnist_train.cache().repeat().shuffle(
        buffer_size=50000).batch(flags_obj.batch_size)
    eval_input_dataset = mnist_test.cache().repeat().batch(
        flags_obj.batch_size)

    with strategy_scope:
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            0.05, decay_steps=100000, decay_rate=0.96)
        optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

        model = build_model()
        model.compile(optimizer=optimizer,
                      loss='sparse_categorical_crossentropy',
                      metrics=['sparse_categorical_accuracy'])

    num_train_examples = mnist.info.splits['train'].num_examples
    train_steps = num_train_examples // flags_obj.batch_size
    train_epochs = flags_obj.train_epochs

    ckpt_full_path = os.path.join(flags_obj.model_dir,
                                  'model.ckpt-{epoch:04d}')
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
                                           save_weights_only=True),
        tf.keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir),
    ]

    num_eval_examples = mnist.info.splits['test'].num_examples
    num_eval_steps = num_eval_examples // flags_obj.batch_size

    history = model.fit(train_input_dataset,
                        epochs=train_epochs,
                        steps_per_epoch=train_steps,
                        callbacks=callbacks,
                        validation_steps=num_eval_steps,
                        validation_data=eval_input_dataset,
                        validation_freq=flags_obj.epochs_between_evals)

    export_path = os.path.join(flags_obj.model_dir, 'saved_model')
    model.save(export_path, include_optimizer=False)

    eval_output = model.evaluate(eval_input_dataset,
                                 steps=num_eval_steps,
                                 verbose=2)

    stats = common.build_stats(history, eval_output, callbacks)
    return stats
Example #4
from config import Config

output_dir = Config['output_dir']
en_vocab_file = os.path.join(output_dir, 'en_vocab')
zh_vocab_file = os.path.join(output_dir, 'zh_vocab')
checkpoint_path = os.path.join(output_dir, 'checkpoints')
log_dir = os.path.join(output_dir, 'logs')

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

ds_config = tfds.translate.wmt.WmtConfig(
    version='1.0.0',
    language_pair=('zh', 'en'),
    subsets={tfds.Split.TRAIN: ['newscommentary_v14']})
builder = tfds.builder('wmt_translate', config=ds_config)
builder.download_and_prepare()
train_examples, val_examples = builder.as_dataset(
    split=['train[:30%]', 'train[30%:31%]'], as_supervised=True)

print('-' * 50)
try:
    subword_encoder_en = tfds.features.text.SubwordTextEncoder.load_from_file(
        en_vocab_file)
    print(f'Loaded existing vocabulary: {en_vocab_file}')
except Exception:
    print(f'Building vocabulary: {en_vocab_file}')
    subword_encoder_en = tfds.features.text.SubwordTextEncoder.build_from_corpus(
        (en.numpy() for en, _ in train_examples), target_vocab_size=2**13)
    subword_encoder_en.save_to_file(en_vocab_file)
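
# The Chinese-side vocabulary can be built symmetrically. A sketch mirroring
# the English branch above (the smaller max_subword_length for Chinese is an
# assumption, not part of the original snippet):
try:
    subword_encoder_zh = tfds.features.text.SubwordTextEncoder.load_from_file(
        zh_vocab_file)
    print(f'Loaded existing vocabulary: {zh_vocab_file}')
except Exception:
    print(f'Building vocabulary: {zh_vocab_file}')
    subword_encoder_zh = tfds.features.text.SubwordTextEncoder.build_from_corpus(
        (zh.numpy() for _, zh in train_examples),
        target_vocab_size=2**13,
        max_subword_length=1)
    subword_encoder_zh.save_to_file(zh_vocab_file)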
Example #5
def fetch_data_via_tf_datasets(dataset_name):
    builder = tfds.builder(name=dataset_name)
    builder.download_and_prepare()
    data = builder.as_dataset(shuffle_files=False)
    return data
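
# Possible usage (hypothetical dataset name; the data is downloaded on the
# first call):
#   splits = fetch_data_via_tf_datasets('mnist')
#   train_ds, test_ds = splits['train'], splits['test']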
Example #6
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Enable training summary.
    if FLAGS.train_summary_steps > 0:
        tf.config.set_soft_device_placement(True)

    builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir)
    builder.download_and_prepare()
    num_train_examples = builder.info.splits[FLAGS.train_split].num_examples
    num_eval_examples = builder.info.splits[FLAGS.eval_split].num_examples
    num_classes = builder.info.features['label'].num_classes

    train_steps = model_util.get_train_steps(num_train_examples)
    eval_steps = int(math.ceil(num_eval_examples / FLAGS.eval_batch_size))
    epoch_steps = int(round(num_train_examples / FLAGS.train_batch_size))

    resnet.BATCH_NORM_DECAY = FLAGS.batch_norm_decay
    model = resnet.resnet_v1(resnet_depth=FLAGS.resnet_depth,
                             width_multiplier=FLAGS.width_multiplier,
                             cifar_stem=FLAGS.image_size <= 32)

    checkpoint_steps = (FLAGS.checkpoint_steps
                        or (FLAGS.checkpoint_epochs * epoch_steps))

    cluster = None
    if FLAGS.use_tpu and FLAGS.master is None:
        if FLAGS.tpu_name:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
                FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
        else:
            cluster = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(cluster)
            tf.tpu.experimental.initialize_tpu_system(cluster)

    default_eval_mode = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V1
    sliced_eval_mode = tf.estimator.tpu.InputPipelineConfig.SLICED
    run_config = tf.estimator.tpu.RunConfig(
        tpu_config=tf.estimator.tpu.TPUConfig(
            iterations_per_loop=checkpoint_steps,
            eval_training_input_configuration=sliced_eval_mode
            if FLAGS.use_tpu else default_eval_mode),
        model_dir=FLAGS.model_dir,
        save_summary_steps=checkpoint_steps,
        save_checkpoints_steps=checkpoint_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max,
        master=FLAGS.master,
        cluster=cluster)
    estimator = tf.estimator.tpu.TPUEstimator(
        model_lib.build_model_fn(model, num_classes, num_train_examples),
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        use_tpu=FLAGS.use_tpu)

    if FLAGS.mode == 'eval':
        for ckpt in tf.train.checkpoints_iterator(run_config.model_dir,
                                                  min_interval_secs=15):
            try:
                result = perform_evaluation(estimator=estimator,
                                            input_fn=data_lib.build_input_fn(
                                                builder, False),
                                            eval_steps=eval_steps,
                                            model=model,
                                            num_classes=num_classes,
                                            checkpoint_path=ckpt)
            except tf.errors.NotFoundError:
                continue
            if result['global_step'] >= train_steps:
                return
    else:
        estimator.train(data_lib.build_input_fn(builder, True),
                        max_steps=train_steps)
        if FLAGS.mode == 'train_then_eval':
            perform_evaluation(estimator=estimator,
                               input_fn=data_lib.build_input_fn(
                                   builder, False),
                               eval_steps=eval_steps,
                               model=model,
                               num_classes=num_classes)
Example #7
def get_uniform_size_builder(num_points=1024):
    if num_points != 1024:
        raise NotImplementedError()
    return tfds.builder('modelnet40/cloud1024')
Example #8
    def __init__(self,
                 params: cfg.DataConfig,
                 dataset_fn=tf.data.TFRecordDataset,
                 decoder_fn: Optional[Callable[..., Any]] = None,
                 combine_fn: Optional[Callable[..., Any]] = None,
                 sample_fn: Optional[Callable[..., Any]] = None,
                 parser_fn: Optional[Callable[..., Any]] = None,
                 transform_and_batch_fn: Optional[Callable[
                     [tf.data.Dataset, Optional[tf.distribute.InputContext]],
                     tf.data.Dataset]] = None,
                 postprocess_fn: Optional[Callable[..., Any]] = None):
        """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      dataset_fn: A callable that builds a `tf.data.Dataset` from the input
        files, for example `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      combine_fn: An optional `callable` that takes a dictionary of
        `tf.data.Dataset` objects as input and outputs a combined dataset. It
        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
        input and outputs the transformed dataset. It performs sampling on the
        decoded raw tensors dict before the parser_fn.
      parser_fn: An optional `callable` that takes the decoded raw tensors dict
        and parses them into a dictionary of tensors that can be consumed by
        the model. It will be executed after decoder_fn.
      transform_and_batch_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
        input, and returns a `tf.data.Dataset` object. It will be executed after
        `parser_fn` to transform and batch the dataset; if None, after
        `parser_fn` is executed, the dataset will be batched into per-replica
        batch size.
      postprocess_fn: An optional `callable` that processes batched tensors. It
        will be executed after batching.
    """
        if params.input_path and params.tfds_name:
            raise ValueError(
                'At most one of `input_path` and `tfds_name` can be '
                'specified, but got %s and %s.' %
                (params.input_path, params.tfds_name))

        if isinstance(params.input_path,
                      cfg.base_config.Config) and combine_fn is None:
            raise ValueError(
                'A `combine_fn` is required if the `input_path` is a dictionary.'
            )

        self._tfds_builder = None
        self._matched_files = None
        if not params.input_path:
            # Read dataset from TFDS.
            if not params.tfds_split:
                raise ValueError(
                    '`tfds_name` is %s, but `tfds_split` is not specified.' %
                    params.tfds_name)
            self._tfds_builder = tfds.builder(params.tfds_name,
                                              data_dir=params.tfds_data_dir)
        else:
            self._matched_files = self.get_files(params.input_path)

        self._global_batch_size = params.global_batch_size
        self._is_training = params.is_training
        self._drop_remainder = params.drop_remainder
        self._shuffle_buffer_size = params.shuffle_buffer_size
        self._cache = params.cache
        self._cycle_length = params.cycle_length
        self._block_length = params.block_length
        self._deterministic = params.deterministic
        self._sharding = params.sharding
        self._tfds_split = params.tfds_split
        self._tfds_as_supervised = params.tfds_as_supervised
        self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

        self._dataset_fn = dataset_fn
        self._decoder_fn = decoder_fn
        self._combine_fn = combine_fn
        self._sample_fn = sample_fn
        self._parser_fn = parser_fn
        self._transform_and_batch_fn = transform_and_batch_fn
        self._postprocess_fn = postprocess_fn
        self._seed = params.seed

        # When tf.data service is enabled, each data service worker should get
        # different random seeds. Thus, we set `seed` to None.
        # Sharding should also be disabled because the tf.data service handles
        # how each worker shards data via `processing_mode` in the distribute
        # method.
        if params.enable_tf_data_service:
            self._seed = None
            self._sharding = False

        self._enable_tf_data_service = (params.enable_tf_data_service
                                        and params.tf_data_service_address)
        self._tf_data_service_address = params.tf_data_service_address
        if self._enable_tf_data_service:
            # Add a random seed as the tf.data service job name suffix, so the
            # tf.data service doesn't reuse previous state if a TPU worker gets
            # preempted. The global batch size must also be part of the job
            # name: when batch size is tuned (e.g. with Vizier) while tf.data
            # service is enabled, each trial needs a distinct job name, since a
            # changed batch size makes the dataset a different instance from
            # the tf.data perspective. Otherwise the model would read tensors
            # from the wrong tf.data service job, causing a mismatch on the
            # batch size dimension.
            self._tf_data_service_job_name = (
                f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
                f'{self.static_randnum}')
            self._enable_round_robin_tf_data_service = params.get(
                'enable_round_robin_tf_data_service', False)
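
# Hypothetical usage sketch (field and helper names are assumptions, not part
# of the snippet above):
#   params = cfg.DataConfig(tfds_name='mnist', tfds_split='train',
#                           global_batch_size=64, is_training=True)
#   reader = InputReader(params, decoder_fn=my_decoder_fn)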
Example #9
def main(argv):
    del argv

    builder = tfds.builder('imagenet2012', version='5.1.0')
    decoders = {'image': tfds.decode.SkipDecoding()}

    read_config = tfds.ReadConfig(interleave_cycle_length=96,
                                  interleave_block_length=2)

    train_dataset_size = builder.info.splits[tfds.Split.TRAIN].num_examples
    train_split = tfds.Split.TRAIN
    if FLAGS.subsample:
        train_dataset_size = int(round(train_dataset_size * FLAGS.subsample))
        train_split = tfds.core.ReadInstruction(train_split,
                                                to=FLAGS.subsample * 100,
                                                unit='%')
    train_dataset = builder.as_dataset(train_split,
                                       decoders=decoders,
                                       shuffle_files=False,
                                       read_config=read_config).cache()
    train_dataset = train_dataset.shuffle(train_dataset_size).repeat()
    train_dataset = train_dataset.map(
        functools.partial(preprocess_data, is_training=True),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    train_dataset = train_dataset.batch(FLAGS.batch_size, drop_remainder=True)
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    test_dataset = builder.as_dataset(tfds.Split.VALIDATION, decoders=decoders)
    test_dataset = test_dataset.map(
        functools.partial(preprocess_data, is_training=False),
        num_parallel_calls=tf.data.experimental.AUTOTUNE).cache()
    test_dataset = test_dataset.batch(FLAGS.batch_size)
    test_dataset = test_dataset.prefetch(tf.data.experimental.AUTOTUNE)
    test_dataset_size = builder.info.splits[tfds.Split.VALIDATION].num_examples

    steps_per_epoch = train_dataset_size // FLAGS.batch_size
    steps_between_evals = int(FLAGS.epochs_between_evals * steps_per_epoch)
    train_steps = FLAGS.epochs * steps_per_epoch
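    # Ceil division, so the final partial validation batch is still evaluated.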
    eval_steps = ((test_dataset_size - 1) // FLAGS.batch_size) + 1

    model_dir_name = (
        '%s-depth-%s-width-%s-bs-%d-lr-%f-reg-%f-dropout-%f-aa-%s' %
        (FLAGS.model, FLAGS.depth_multiplier, FLAGS.width_multiplier,
         FLAGS.batch_size, FLAGS.learning_rate, FLAGS.weight_decay,
         FLAGS.dropout_rate, FLAGS.use_autoaugment))
    if FLAGS.copy > 0:
        model_dir_name += '-copy-%d' % FLAGS.copy
    experiment_dir = os.path.join(FLAGS.base_dir, model_dir_name)

    def model_optimizer_fn():
        schedule = tf.keras.experimental.CosineDecay(FLAGS.learning_rate,
                                                     train_steps)
        if FLAGS.model == 'efficientnet':
            config = efficientnet_model.ModelConfig.from_args(
                width_coefficient=FLAGS.width_multiplier,
                depth_coefficient=FLAGS.depth_multiplier,
                resolution=224,
                weight_decay=FLAGS.weight_decay,
                dropout_rate=FLAGS.dropout_rate)
            model = efficientnet_model.EfficientNet(config)
        elif FLAGS.model == 'resnet':
            model = alt_resnet.Resnet(
                block_fn=alt_resnet.BottleneckBlock,
                layers=[3, 4, int(round(FLAGS.depth_multiplier * 6)), 3],
                width_multipliers=[1, 1, 1, FLAGS.width_multiplier, 1],
                num_classes=1000,
                kernel_regularizer=tf.keras.regularizers.l2(
                    FLAGS.weight_decay))
        elif FLAGS.model == 'resnet_scale_all':
            model = alt_resnet.Resnet(
                block_fn=alt_resnet.BottleneckBlock,
                layers=[
                    int(round(FLAGS.depth_multiplier * x))
                    for x in [3, 4, 6, 3]
                ],
                width_multipliers=[1] + [FLAGS.width_multiplier] * 4,
                num_classes=1000,
                kernel_regularizer=tf.keras.regularizers.l2(
                    FLAGS.weight_decay))
        else:
            raise ValueError('Unknown model {}'.format(FLAGS.model))
        optimizer = tf.keras.optimizers.SGD(schedule, momentum=0.9)
        return model, optimizer

    train_lib.train(model_optimizer_fn=model_optimizer_fn,
                    train_steps=train_steps,
                    eval_steps=eval_steps,
                    steps_between_evals=steps_between_evals,
                    train_dataset=train_dataset,
                    test_dataset=test_dataset,
                    experiment_dir=experiment_dir)
Example #10
import cv2
import numpy as np
import tensorflow_datasets as tfds

# Randomly sample 10,000 TinyImageNet images. `data` is assumed to have been
# loaded earlier; that step is not shown in this snippet.
indices = np.random.randint(0, 100000, 10000)
TIM = data[0][indices]
data = 1  # free the original array
TIM = np.reshape(TIM, [10000, 64, 64, 3]) / 255.
TIM1 = []
for i in range(len(TIM)):
    im = cv2.resize(TIM[i], dsize=(32, 32), interpolation=cv2.INTER_CUBIC)
    TIM1.append(im)
TIM1 = np.array(TIM1)
np.save("TIM", TIM1)
TIM1 = 1  # free memory
#TIM = np.load("./TIM.npy")

# Get LSUN data via TFDS and convert it to NumPy.
builder = tfds.builder("lsun")
builder.download_and_prepare()
datasets = builder.as_dataset()
np_datasets = tfds.as_numpy(datasets)

LSUN = []
for example in np_datasets["train"]:
    image = example['image']
    res = cv2.resize(image, dsize=(32, 32), interpolation=cv2.INTER_CUBIC)
    LSUN.append(res)
LSUN = np.array(LSUN)
np.save("LSUN.npy", LSUN[:10000])
Example #11
def _train_and_eval_dataset(dataset_name,
                            data_dir,
                            eval_holdout_size,
                            train_shuffle_files=True,
                            eval_shuffle_files=False):
    """Return train and evaluation datasets, feature info and supervised keys.

  Args:
    dataset_name: a string, the name of the dataset; if it starts with 't2t_'
      then we'll search T2T Problem registry for it, otherwise we assume it
      is a dataset from TFDS and load it from there.
    data_dir: directory where the data is located.
    eval_holdout_size: float from 0 to <1; if >0 use this much of training data
      for evaluation (instead of looking for a pre-specified VALIDATION split).
    train_shuffle_files: Boolean determining whether or not to shuffle the train
      files at startup. Set to False if you want data determinism.
    eval_shuffle_files: Boolean determining whether or not to shuffle the test
      files at startup. Set to False if you want data determinism.

  Returns:
    a 3-tuple consisting of:
     * the train tf.Dataset
     * the eval tf.Dataset
     * supervised_keys: what is the input and what is the target, i.e., a
         pair of lists with input and target feature names (or None if the
         dataset defines no supervised keys).
  """
    if dataset_name.startswith('t2t_'):
        return _train_and_eval_dataset_v1(dataset_name[4:], data_dir,
                                          train_shuffle_files,
                                          eval_shuffle_files)
    dataset_builder = tfds.builder(dataset_name, data_dir=data_dir)
    info = dataset_builder.info
    splits = dataset_builder.info.splits
    if tfds.Split.TRAIN not in splits:
        raise ValueError('To train we require a train split in the dataset.')
    train_split = tfds.Split.TRAIN
    if eval_holdout_size > 0:
        holdout_percentage = int(eval_holdout_size * 100.0)
        train_percentage = 100 - holdout_percentage
        train_split = f'train[:{train_percentage}%]'
        eval_split = f'train[{train_percentage}%:]'
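        # e.g. eval_holdout_size=0.1 gives train[:90%] for training and
        # train[90%:] for eval.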
    else:
        if tfds.Split.VALIDATION not in splits and 'test' not in splits:
            raise ValueError(
                'We require a validation or test split in the dataset.')
        eval_split = tfds.Split.VALIDATION
        if tfds.Split.VALIDATION not in splits:
            eval_split = tfds.Split.TEST
    train = tfds.load(name=dataset_name,
                      split=train_split,
                      data_dir=data_dir,
                      shuffle_files=train_shuffle_files)
    valid = tfds.load(name=dataset_name,
                      split=eval_split,
                      data_dir=data_dir,
                      shuffle_files=eval_shuffle_files)
    keys = None
    if info.supervised_keys:
        keys = ([info.supervised_keys[0]], [info.supervised_keys[1]])
    return train, valid, keys
Example #12
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    ds_info = tfds.builder(FLAGS.dataset).info
    per_core_batch_size = FLAGS.per_core_batch_size // FLAGS.ensemble_size
    batch_size = per_core_batch_size * FLAGS.num_cores
    # train_proportion is a float, so steps_per_epoch must be cast back to int.
    steps_per_epoch = int(
        (ds_info.splits['train'].num_examples * FLAGS.train_proportion) //
        batch_size)
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    train_dataset = utils.load_dataset(split=tfds.Split.TRAIN,
                                       name=FLAGS.dataset,
                                       batch_size=batch_size,
                                       use_bfloat16=FLAGS.use_bfloat16,
                                       proportion=FLAGS.train_proportion)
    clean_test_dataset = utils.load_dataset(split=tfds.Split.TEST,
                                            name=FLAGS.dataset,
                                            batch_size=batch_size,
                                            use_bfloat16=FLAGS.use_bfloat16)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar10':
            load_c_dataset = utils.load_cifar10_c
        else:
            load_c_dataset = functools.partial(utils.load_cifar100_c,
                                               path=FLAGS.cifar100_c_path)
        corruption_types, max_intensity = utils.load_corrupted_test_info(
            FLAGS.dataset)
        for corruption in corruption_types:
            for intensity in range(1, max_intensity + 1):
                dataset = load_c_dataset(corruption_name=corruption,
                                         corruption_intensity=intensity,
                                         batch_size=batch_size,
                                         use_bfloat16=FLAGS.use_bfloat16)
                test_datasets['{0}_{1}'.format(corruption, intensity)] = (
                    strategy.experimental_distribute_dataset(dataset))

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras model')
        model = ub.models.wide_resnet_sngp_be(
            input_shape=ds_info.features['image'].shape,
            batch_size=batch_size,
            depth=28,
            width_multiplier=10,
            num_classes=num_classes,
            ensemble_size=FLAGS.ensemble_size,
            random_sign_init=FLAGS.random_sign_init,
            l2=FLAGS.l2,
            use_gp_layer=FLAGS.use_gp_layer,
            gp_input_dim=FLAGS.gp_input_dim,
            gp_hidden_dim=FLAGS.gp_hidden_dim,
            gp_scale=FLAGS.gp_scale,
            gp_bias=FLAGS.gp_bias,
            gp_input_normalization=FLAGS.gp_input_normalization,
            gp_cov_discount_factor=FLAGS.gp_cov_discount_factor,
            gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty,
            use_spec_norm=FLAGS.use_spec_norm,
            spec_norm_iteration=FLAGS.spec_norm_iteration,
            spec_norm_bound=FLAGS.spec_norm_bound)

        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = utils.LearningRateSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood': tf.keras.metrics.Mean(),
            'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss': tf.keras.metrics.Mean(),
            'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood': tf.keras.metrics.Mean(),
            'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/stddev': tf.keras.metrics.Mean(),
        }
        for i in range(FLAGS.ensemble_size):
            metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
            metrics['test/accuracy_member_{}'.format(i)] = (
                tf.keras.metrics.SparseCategoricalAccuracy())
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
                    corrupt_metrics['test/stddev_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
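            # BatchEnsemble: tile the batch so every ensemble member sees each
            # example within a single forward pass.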
            images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
            labels = tf.tile(labels, [FLAGS.ensemble_size])

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if isinstance(logits, tuple):
                    # If model returns a tuple of (logits, covmat), extract logits
                    logits, _ = logits
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, since TPUStrategy reduce-sums gradients across replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = []
                for grad, var in zip(grads, model.trainable_variables):
                    # Apply a different learning rate to the fast-weight
                    # approximate posterior/prior parameters. This excludes BN
                    # and slow weights, but pay attention to the naming scheme.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grads_and_vars.append(
                            (grad * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grads_and_vars.append((grad, var))
                optimizer.apply_gradients(grads_and_vars)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs

            logits_list = []
            stddev_list = []

            for i in range(FLAGS.ensemble_size):
                logits = model(images, training=False)
                if isinstance(logits, tuple):
                    # If model returns a tuple of (logits, covmat), extract both
                    logits, covmat = logits
                else:
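                    # No GP head: fall back to an identity covariance.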
                    covmat = tf.eye(FLAGS.per_core_batch_size)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
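                # Mean-field approximation: temper each logit by its GP
                # posterior variance, roughly logits / sqrt(1 + lambda * var),
                # with lambda = FLAGS.gp_mean_field_factor.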
                logits = mean_field_logits(
                    logits,
                    covmat,
                    mean_field_factor=FLAGS.gp_mean_field_factor)
                stddev = tf.sqrt(tf.linalg.diag_part(covmat))

                stddev_list.append(stddev)
                logits_list.append(logits)

                member_probs = tf.nn.softmax(logits)
                member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                metrics['test/nll_member_{}'.format(i)].update_state(
                    member_loss)
                metrics['test/accuracy_member_{}'.format(i)].update_state(
                    labels, member_probs)
            # Logits dimension is (num_samples, batch_size, num_classes).
            logits_list = tf.stack(logits_list, axis=0)
            stddev_list = tf.stack(stddev_list, axis=0)

            stddev = tf.reduce_mean(stddev_list, axis=0)
            probs_list = tf.nn.softmax(logits_list)
            probs = tf.reduce_mean(probs_list, axis=0)

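            # Ensemble NLL of the mixture: -log(1/M * sum_m p_m(y|x)), computed
            # via logsumexp over the per-member log-likelihoods.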
            labels_broadcasted = tf.broadcast_to(
                labels, [FLAGS.ensemble_size, labels.shape[0]])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_broadcasted, logits_list, from_logits=True)
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
                tf.math.log(float(FLAGS.ensemble_size)))

            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
                metrics['test/stddev'].update_state(stddev)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/stddev_{}'.format(
                    dataset_name)].update_state(stddev)

        strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)

        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_start_time = time.time()
                test_step(test_iterator, dataset_name)
                # Convert seconds to milliseconds per example.
                ms_per_example = (time.time() -
                                  test_start_time) * 1e3 / batch_size
                metrics['test/ms_per_example'].update_state(ms_per_example)

            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types, max_intensity)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        for i in range(FLAGS.ensemble_size):
            logging.info(
                'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                metrics['test/nll_member_{}'.format(i)].result(),
                metrics['test/accuracy_member_{}'.format(i)].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
def info(self):
    """Lazily builds and memoizes the tfds DatasetInfo for this dataset."""
    if not hasattr(self, "_info"):
        self._info = tfds.builder(self.name, data_dir=self.data_dir).info
    return self._info
Example #14
import tensorflow as tf
import tensorflow_datasets as tfds
import pickle
import time
import os
import datetime

# Imports with different functions to preprocess the dataset:
from snli_utils import get_vocab, gen_to_encoded_list, encoded_list_to_dataset

# Remember to set your own pickle, checkpoint, and summary directories if you want different ones.

# Get the dataset: -----------------------------------------------------------------------------------------------------
builder_obj = tfds.builder('snli')
print(builder_obj.info)

datasets_splits_dict = builder_obj.as_dataset()  # returns all splits in a dict
train_dataset, val_dataset, test_dataset = (datasets_splits_dict["train"],
                                            datasets_splits_dict["validation"],
                                            datasets_splits_dict["test"])

# Get the iterators to encode the elements into lists of integers representing the words for each sentence:
train_np, val_np, test_np = tfds.as_numpy(train_dataset), tfds.as_numpy(val_dataset), tfds.as_numpy(test_dataset)

# We first need the vocab ----------------------------------------------------------------------------------------------

if not os.path.isfile(r'./generator_vocab.pickle'):
    # Get the vocabulary for the first time:
    vocab = list(get_vocab([train_np, val_np, test_np]))
    print("\nVocab ready!")
else:
    # Get the vocab (built by the 'get_vocab' function in snli_utils) from the
    # pickle. The original snippet is cut off here; loading the pickled vocab
    # is the assumed counterpart of the branch above.
    with open(r'./generator_vocab.pickle', 'rb') as f:
        vocab = pickle.load(f)
Example #15
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    aug_params = {
        'augmix': FLAGS.augmix,
        'aug_count': FLAGS.aug_count,
        'augmix_depth': FLAGS.augmix_depth,
        'augmix_prob_coeff': FLAGS.augmix_prob_coeff,
        'augmix_width': FLAGS.augmix_width,
        'label_smoothing': FLAGS.label_smoothing,
        'ensemble_size': FLAGS.ensemble_size,
        'mixup_alpha': FLAGS.mixup_alpha,
        'random_augment': FLAGS.random_augment,
        'adaptive_mixup': FLAGS.adaptive_mixup,
        'forget_mixup': FLAGS.forget_mixup,
        'num_cores': FLAGS.num_cores,
        'threshold': FLAGS.forget_threshold,
        'cutmix': FLAGS.cutmix,
    }
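    # Divide the per-core batch by the ensemble size: train_step tiles each
    # batch ensemble_size times (or reshapes augmented copies) before the
    # forward pass.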
    batch_size = ((FLAGS.per_core_batch_size // FLAGS.ensemble_size) *
                  FLAGS.num_cores)
    train_input_fn = data_utils.load_input_fn(
        split=tfds.Split.TRAIN,
        name=FLAGS.dataset,
        batch_size=batch_size,
        use_bfloat16=FLAGS.use_bfloat16,
        proportion=FLAGS.train_proportion,
        validation_set=FLAGS.validation,
        aug_params=aug_params)
    if FLAGS.validation:
        validation_input_fn = data_utils.load_input_fn(
            split=tfds.Split.VALIDATION,
            name=FLAGS.dataset,
            batch_size=FLAGS.per_core_batch_size,
            use_bfloat16=FLAGS.use_bfloat16,
            validation_set=True)
        val_dataset = strategy.experimental_distribute_datasets_from_function(
            validation_input_fn)
    clean_test_input_fn = data_utils.load_input_fn(
        split=tfds.Split.TEST,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size // FLAGS.ensemble_size,
        use_bfloat16=FLAGS.use_bfloat16)
    train_dataset = strategy.experimental_distribute_dataset(train_input_fn())
    test_datasets = {
        'clean':
        strategy.experimental_distribute_datasets_from_function(
            clean_test_input_fn),
    }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar10':
            load_c_dataset = utils.load_cifar10_c
        else:
            load_c_dataset = functools.partial(utils.load_cifar100_c,
                                               path=FLAGS.cifar100_c_path)
        corruption_types, max_intensity = utils.load_corrupted_test_info(
            FLAGS.dataset)
        for corruption in corruption_types:
            for intensity in range(1, max_intensity + 1):
                dataset = load_c_dataset(corruption_name=corruption,
                                         corruption_intensity=intensity,
                                         batch_size=batch_size,
                                         use_bfloat16=FLAGS.use_bfloat16)
                test_datasets['{0}_{1}'.format(corruption, intensity)] = (
                    strategy.experimental_distribute_dataset(dataset))

    ds_info = tfds.builder(FLAGS.dataset).info
    num_train_examples = ds_info.splits['train'].num_examples
    # train_proportion is a float, so steps_per_epoch must be cast to int.
    if FLAGS.validation:
        # TODO(ywenxu): Remove hard-coding validation images.
        steps_per_epoch = int(
            (num_train_examples * FLAGS.train_proportion - 2500) // batch_size)
        steps_per_val = 2500 // (FLAGS.per_core_batch_size * FLAGS.num_cores)
    else:
        steps_per_epoch = int(
            num_train_examples * FLAGS.train_proportion) // batch_size
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras model')
        model = batchensemble_model.wide_resnet(
            input_shape=ds_info.features['image'].shape,
            depth=28,
            width_multiplier=10,
            num_classes=num_classes,
            ensemble_size=FLAGS.ensemble_size,
            random_sign_init=FLAGS.random_sign_init,
            l2=FLAGS.l2,
            use_ensemble_bn=FLAGS.use_ensemble_bn)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = utils.LearningRateSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)

        diversity_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            FLAGS.diversity_coeff,
            FLAGS.diversity_decay_epoch * steps_per_epoch,
            decay_rate=0.97,
            staircase=True)

        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/similarity':
            tf.keras.metrics.Mean(),
            'train/l2':
            tf.keras.metrics.Mean(),
            'train/ece':
            um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/member_accuracy_mean':
            (tf.keras.metrics.SparseCategoricalAccuracy()),
            'test/ece':
            um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/member_ece_mean':
            um.ExpectedCalibrationError(num_bins=FLAGS.num_bins)
        }
        for i in range(FLAGS.ensemble_size):
            metrics['test/nll_member_{}'.format(i)] = tf.keras.metrics.Mean()
            metrics['test/accuracy_member_{}'.format(i)] = (
                tf.keras.metrics.SparseCategoricalAccuracy())
            metrics['test/ece_member_{}'.format(i)] = (
                um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

        test_diversity = {}
        training_diversity = {}
        corrupt_diversity = {}
        if FLAGS.ensemble_size > 1:
            test_diversity = {
                'test/disagreement': tf.keras.metrics.Mean(),
                'test/average_kl': tf.keras.metrics.Mean(),
                'test/cosine_similarity': tf.keras.metrics.Mean(),
            }
            training_diversity = {
                'train/disagreement': tf.keras.metrics.Mean(),
                'train/average_kl': tf.keras.metrics.Mean(),
                'train/cosine_similarity': tf.keras.metrics.Mean(),
            }

        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))
                    corrupt_metrics['test/member_acc_mean_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/member_ece_mean_{}'.format(
                        dataset_name)] = (um.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))
                    corrupt_diversity['corrupt_diversity/average_kl_{}'.format(
                        dataset_name)] = tf.keras.metrics.Mean()
                    corrupt_diversity[
                        'corrupt_diversity/cosine_similarity_{}'.format(
                            dataset_name)] = tf.keras.metrics.Mean()
                    corrupt_diversity[
                        'corrupt_diversity/disagreement_{}'.format(
                            dataset_name)] = tf.keras.metrics.Mean()

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            if FLAGS.forget_mixup:
                images, labels, idx = inputs
            else:
                images, labels = inputs
            if FLAGS.adaptive_mixup or FLAGS.forget_mixup:
                images = tf.identity(images)
            elif FLAGS.augmix or FLAGS.random_augment:
                images_shape = tf.shape(images)
                images = tf.reshape(
                    tf.transpose(images, [1, 0, 2, 3, 4]),
                    [-1, images_shape[2], images_shape[3], images_shape[4]])
            else:
                images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
            # AugMix, adaptive mixup, and forget mixup preprocessing already
            # produce tiled labels.
            if FLAGS.mixup_alpha > 0 or FLAGS.label_smoothing > 0 or FLAGS.cutmix:
                if FLAGS.augmix or FLAGS.adaptive_mixup or FLAGS.forget_mixup:
                    labels = tf.identity(labels)
                else:
                    labels = tf.tile(labels, [FLAGS.ensemble_size, 1])
            else:
                labels = tf.tile(labels, [FLAGS.ensemble_size])

            def _is_batch_norm(v):
                """Decide whether a variable belongs to `batch_norm`."""
                keywords = ['batchnorm', 'batch_norm', 'bn']
                return any(k in v.name.lower() for k in keywords)

            def _normalize(x):
                """Normalize an input with l2 norm."""
                l2 = tf.norm(x, ord=2, axis=-1)
                return x / tf.expand_dims(l2, axis=-1)

            # Sum of the upper triangle of XX^T, divided by the ensemble size.
            def pairwise_cosine_distance(x):
                """Compute the pairwise distance in a matrix."""
                normalized_x = _normalize(x)
                return (tf.reduce_sum(
                    tf.matmul(normalized_x, normalized_x, transpose_b=True)) -
                        FLAGS.ensemble_size) / (2.0 * FLAGS.ensemble_size)

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                if FLAGS.mixup_alpha > 0 or FLAGS.label_smoothing > 0 or FLAGS.cutmix:
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.categorical_crossentropy(
                            labels, logits, from_logits=True))
                else:
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True))

                l2_loss = sum(model.losses)
                fast_weights = [
                    var for var in model.trainable_variables
                    if not _is_batch_norm(var) and (
                        'alpha' in var.name or 'gamma' in var.name)
                ]

                pairwise_distance_loss = tf.add_n(
                    [pairwise_cosine_distance(var) for var in fast_weights])

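                # The diversity penalty only kicks in after
                # diversity_start_epoch, then decays via diversity_schedule.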
                diversity_start_iter = steps_per_epoch * FLAGS.diversity_start_epoch
                diversity_iterations = optimizer.iterations - diversity_start_iter
                if diversity_iterations > 0:
                    diversity_coeff = diversity_schedule(diversity_iterations)
                    diversity_loss = diversity_coeff * pairwise_distance_loss
                    loss = negative_log_likelihood + l2_loss + diversity_loss
                else:
                    loss = negative_log_likelihood + l2_loss
                # Scale the loss, since TPUStrategy reduce-sums gradients across replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = []
                for grad, var in zip(grads, model.trainable_variables):
                    # Apply a different learning rate to the fast-weight
                    # approximate posterior/prior parameters. This excludes BN
                    # and slow weights, but pay attention to the naming scheme.
                    if (not _is_batch_norm(var) and 'kernel' not in var.name):
                        grads_and_vars.append(
                            (grad * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grads_and_vars.append((grad, var))
                optimizer.apply_gradients(grads_and_vars)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            probs = tf.nn.softmax(logits)
            if FLAGS.ensemble_size > 1:
                per_probs = tf.reshape(
                    probs,
                    tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
                diversity_results = um.average_pairwise_diversity(
                    per_probs, FLAGS.ensemble_size)
                for k, v in diversity_results.items():
                    training_diversity['train/' + k].update_state(v)

            if FLAGS.mixup_alpha > 0 or FLAGS.label_smoothing > 0 or FLAGS.cutmix:
                labels = tf.argmax(labels, axis=-1)
            metrics['train/ece'].update_state(labels, probs)
            metrics['train/similarity'].update_state(pairwise_distance_loss)
            metrics['train/l2'].update_state(l2_loss)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)
            if FLAGS.forget_mixup:
                train_predictions = tf.argmax(probs, -1)
                labels = tf.cast(labels, train_predictions.dtype)
                # For each ensemble member, we accumulate the accuracy counts.
                accuracy_counts = tf.cast(
                    tf.reshape((train_predictions == labels),
                               [FLAGS.ensemble_size, -1]), tf.float32)
                return accuracy_counts, idx

        if FLAGS.forget_mixup:
            return strategy.run(step_fn, args=(next(iterator), ))
        else:
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
            logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.softmax(logits)
            per_probs = tf.split(probs,
                                 num_or_size_splits=FLAGS.ensemble_size,
                                 axis=0)
            for i in range(FLAGS.ensemble_size):
                member_probs = per_probs[i]
                if dataset_name == 'clean':
                    member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                        labels, member_probs)
                    metrics['test/nll_member_{}'.format(i)].update_state(
                        member_loss)
                    metrics['test/accuracy_member_{}'.format(i)].update_state(
                        labels, member_probs)
                    metrics['test/member_accuracy_mean'].update_state(
                        labels, member_probs)
                    metrics['test/ece_member_{}'.format(i)].update_state(
                        labels, member_probs)
                    metrics['test/member_ece_mean'].update_state(
                        labels, member_probs)
                elif dataset_name != 'validation':
                    corrupt_metrics['test/member_acc_mean_{}'.format(
                        dataset_name)].update_state(labels, member_probs)
                    corrupt_metrics['test/member_ece_mean_{}'.format(
                        dataset_name)].update_state(labels, member_probs)

            if FLAGS.ensemble_size > 1:
                per_probs_tensor = tf.reshape(
                    probs,
                    tf.concat([[FLAGS.ensemble_size, -1], probs.shape[1:]], 0))
                diversity_results = um.average_pairwise_diversity(
                    per_probs_tensor, FLAGS.ensemble_size)
                if dataset_name == 'clean':
                    for k, v in diversity_results.items():
                        test_diversity['test/' + k].update_state(v)
                elif dataset_name != 'validation':
                    for k, v in diversity_results.items():
                        corrupt_diversity['corrupt_diversity/{}_{}'.format(
                            k, dataset_name)].update_state(v)

            probs = tf.reduce_mean(per_probs, axis=0)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            elif dataset_name != 'validation':
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)

            if dataset_name == 'validation':
                return per_probs_tensor, labels

        if dataset_name == 'validation':
            return strategy.run(step_fn, args=(next(iterator), ))
        else:
            strategy.run(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    start_time = time.time()
    forget_counts_history = []
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        acc_counts_list = []
        idx_list = []
        for step in range(steps_per_epoch):
            if FLAGS.forget_mixup:
                temp_accuracy_counts, temp_idx = train_step(train_iterator)
                acc_counts_list.append(temp_accuracy_counts)
                idx_list.append(temp_idx)
            else:
                train_step(train_iterator)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)

        # At most one of forget_mixup and adaptive_mixup can be true.
        if FLAGS.forget_mixup:
            current_acc = [
                tf.concat(list(acc_counts_list[i].values), axis=1)
                for i in range(len(acc_counts_list))
            ]
            total_idx = [
                tf.concat(list(idx_list[i].values), axis=0)
                for i in range(len(idx_list))
            ]
            current_acc = tf.cast(tf.concat(current_acc, axis=1), tf.int32)
            total_idx = tf.concat(total_idx, axis=0)

            current_forget_path = os.path.join(FLAGS.output_dir,
                                               'forget_counts.npy')
            last_acc_path = os.path.join(FLAGS.output_dir, 'last_acc.npy')
            if epoch == 0:
                forget_counts = tf.zeros(
                    [FLAGS.ensemble_size, num_train_examples], dtype=tf.int32)
                last_acc = tf.zeros([FLAGS.ensemble_size, num_train_examples],
                                    dtype=tf.int32)
            else:
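                # Resuming mid-run: reload the accumulators persisted at the
                # end of the previous epoch if they are not in memory.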
                if 'last_acc' not in locals():
                    with tf.io.gfile.GFile(last_acc_path, 'rb') as f:
                        last_acc = np.load(f)
                    last_acc = tf.cast(tf.convert_to_tensor(last_acc),
                                       tf.int32)
                if 'forget_counts' not in locals():
                    with tf.io.gfile.GFile(current_forget_path, 'rb') as f:
                        forget_counts = np.load(f)
                    forget_counts = tf.cast(
                        tf.convert_to_tensor(forget_counts), tf.int32)

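            # A 'forgetting' event: an example a member classified correctly
            # last epoch is classified incorrectly this epoch.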
            selected_last_acc = tf.gather(last_acc, total_idx, axis=1)
            forget_this_epoch = tf.cast(current_acc < selected_last_acc,
                                        tf.int32)
            forget_this_epoch = tf.transpose(forget_this_epoch)
            target_shape = tf.constant(
                [num_train_examples, FLAGS.ensemble_size])
            current_forget_counts = tf.scatter_nd(
                tf.reshape(total_idx, [-1, 1]), forget_this_epoch,
                target_shape)
            current_forget_counts = tf.transpose(current_forget_counts)
            acc_this_epoch = tf.transpose(current_acc)
            last_acc = tf.scatter_nd(tf.reshape(total_idx, [-1, 1]),
                                     acc_this_epoch, target_shape)
            # This is a lower bound on the true accuracy.
            last_acc = tf.transpose(last_acc)

            # TODO(ywenxu): We count dropped examples as forgotten. Fix this later.
            forget_counts += current_forget_counts
            forget_counts_history.append(forget_counts)
            logging.info('forgetting counts')
            logging.info(tf.stack(forget_counts_history, 0))
            with tf.io.gfile.GFile(
                    os.path.join(FLAGS.output_dir,
                                 'forget_counts_history.npy'), 'wb') as f:
                np.save(f, tf.stack(forget_counts_history, 0).numpy())
            with tf.io.gfile.GFile(current_forget_path, 'wb') as f:
                np.save(f, forget_counts.numpy())
            with tf.io.gfile.GFile(last_acc_path, 'wb') as f:
                np.save(f, last_acc.numpy())
            aug_params['forget_counts_dir'] = current_forget_path

            train_input_fn = data_utils.load_input_fn(
                split=tfds.Split.TRAIN,
                name=FLAGS.dataset,
                batch_size=FLAGS.num_cores *
                (FLAGS.per_core_batch_size // FLAGS.ensemble_size),
                use_bfloat16=FLAGS.use_bfloat16,
                validation_set=FLAGS.validation,
                aug_params=aug_params)
            train_dataset = strategy.experimental_distribute_dataset(
                train_input_fn())
            train_iterator = iter(train_dataset)

        if FLAGS.adaptive_mixup:
            val_iterator = iter(val_dataset)
            logging.info('Testing on validation dataset')
            predictions_list = []
            labels_list = []
            for step in range(steps_per_val):
                temp_predictions, temp_labels = test_step(
                    val_iterator, 'validation')
                predictions_list.append(temp_predictions)
                labels_list.append(temp_labels)
            predictions = [
                tf.concat(list(predictions_list[i].values), axis=1)
                for i in range(len(predictions_list))
            ]
            labels = [
                tf.concat(list(labels_list[i].values), axis=0)
                for i in range(len(labels_list))
            ]
            predictions = tf.concat(predictions, axis=1)
            labels = tf.cast(tf.concat(labels, axis=0), tf.int64)

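            # Per-class calibration signal: accuracy minus confidence, so a
            # positive value means the model is underconfident on that class.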
            def compute_acc_conf(preds, label, focus_class):
                class_preds = tf.boolean_mask(preds,
                                              label == focus_class,
                                              axis=1)
                class_pred_labels = tf.argmax(class_preds, axis=-1)
                confidence = tf.reduce_mean(
                    tf.reduce_max(class_preds, axis=-1), -1)
                accuracy = tf.reduce_mean(tf.cast(
                    class_pred_labels == focus_class, tf.float32),
                                          axis=-1)
                return accuracy - confidence

            calibration_per_class = [
                compute_acc_conf(predictions, labels, i)
                for i in range(num_classes)
            ]
            calibration_per_class = tf.stack(calibration_per_class, axis=1)
            logging.info('calibration per class')
            logging.info(calibration_per_class)
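            # Underconfident classes (accuracy > confidence) get coeff 1.0,
            # i.e. no mixup; the rest keep mixup_alpha.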
            mixup_coeff = tf.where(calibration_per_class > 0, 1.0,
                                   FLAGS.mixup_alpha)
            mixup_coeff = tf.clip_by_value(mixup_coeff, 0, 1)
            logging.info('mixup coeff')
            logging.info(mixup_coeff)
            aug_params['mixup_coeff'] = mixup_coeff
            train_input_fn = data_utils.load_input_fn(
                split=tfds.Split.TRAIN,
                name=FLAGS.dataset,
                batch_size=FLAGS.per_core_batch_size // FLAGS.ensemble_size,
                use_bfloat16=FLAGS.use_bfloat16,
                validation_set=True,
                aug_params=aug_params)
            train_dataset = strategy.experimental_distribute_datasets_from_function(
                train_input_fn)
            train_iterator = iter(train_dataset)

        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator, dataset_name)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            # This includes corrupt_diversity, whose disagreement is
            # normalized by the corrupt mean error rate.
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics,
                corruption_types,
                max_intensity,
                corrupt_diversity=corrupt_diversity,
                output_dir=FLAGS.output_dir)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        for i in range(FLAGS.ensemble_size):
            logging.info(
                'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                metrics['test/nll_member_{}'.format(i)].result(),
                metrics['test/accuracy_member_{}'.format(i)].result() * 100)

        total_metrics = metrics.copy()
        total_metrics.update(training_diversity)
        total_metrics.update(test_diversity)
        total_results = {
            name: metric.result()
            for name, metric in total_metrics.items()
        }
        total_results.update(corrupt_results)
        # Normalize all disagreement metrics (train, test) by test accuracy.
        # Disagreement on the corrupt datasets is normalized by their own
        # error rates.
        test_acc = total_metrics['test/accuracy'].result()
        for name, metric in total_metrics.items():
            if 'disagreement' in name:
                total_results[name] = metric.result() / test_acc

        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)
            if FLAGS.forget_mixup:
                tf.summary.histogram('forget_counts',
                                     forget_counts,
                                     step=epoch + 1)

        for metric in total_metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

        # Need to store the second-to-last checkpoint in the adaptive mixup setup.
        if FLAGS.adaptive_mixup and epoch == (FLAGS.train_epochs - 2):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'last_but_one_checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
    final_save_name = os.path.join(FLAGS.output_dir, 'model')
    model.save(final_save_name)
    logging.info('Saved model to %s', final_save_name)
Example #16
def get_dataset(data_dir, config, dataset_name=None):
    """The training dataset for the code model for fault localization.

    Args:
      data_dir: The data directory to use with tfds.load.
      config: The config for the model.
      dataset_name: If set, use this dataset name in place of the one from the
        config.

    Returns:
      train_dataset: The tf.data.Dataset with batched examples.
      info: The DatasetInfo object containing the feature connectors and other
        info about the dataset.
    """
    dataset_name = dataset_name or config.dataset.name
    split = get_split(config)
    version = (None if config.dataset.version == 'default' else
               config.dataset.version)

    # If in interact mode, use an interactive dataset.
    if config.runner.mode == 'interact':
        dbuilder = tfds.builder(dataset_name,
                                data_dir=data_dir,
                                version=version)
        unused_split_generators = dbuilder._split_generators(dl_manager=None)  # pylint: disable=protected-access
        info = dbuilder.info
        info._builder.set_representation(config.dataset.representation)  # pylint: disable=protected-access
        assert config.dataset.batch_size == 1
        dataset = make_interactive_dataset(info, config)
        if config.dataset.batch:
            dataset = apply_batching(dataset, info, config)
        set_task = cannot_set_task
        return DatasetInfo(dataset=dataset, info=info, set_task=set_task)

    # Load the dataset.
    if config.dataset.in_memory:
        dbuilder = tfds.builder(dataset_name,
                                data_dir=data_dir,
                                version=version)
        unused_split_generators = dbuilder._split_generators(dl_manager=None)  # pylint: disable=protected-access
        dataset, set_task = dbuilder.as_in_memory_dataset(split='all')
        info = dbuilder.info
    else:
        name = dataset_name
        if version is not None:
            name = f'{name}:{version}'
        dataset, info = tfds.load(
            name=name,
            split=split,
            data_dir=data_dir,
            # batch_size=config.dataset.batch_size,
            with_info=True)
        set_task = cannot_set_task

    info._builder.set_representation(config.dataset.representation)  # pylint: disable=protected-access

    verify_reasonable_dataset(dataset_name, info, config)
    dataset = dataset.repeat()
    dataset = apply_filtering(dataset, info, config)
    if config.dataset.batch:
        dataset = apply_batching(dataset, info, config)
    return DatasetInfo(
        dataset=dataset,
        info=info,
        set_task=set_task,
    )
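
# A minimal, self-contained sketch of the tfds.load path used above (an
# illustration only, assuming `tensorflow_datasets` is installed; 'mnist' is a
# stand-in dataset name, not one from the original example):
import tensorflow_datasets as tfds

dataset, info = tfds.load(name='mnist', split='train', with_info=True)
dataset = dataset.repeat().batch(32)
print(info.features)
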
Example #17
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

mnist_bldr = tfds.builder('mnist')
mnist_bldr.download_and_prepare()
datasets = mnist_bldr.as_dataset(shuffle_files=False)
mnist_train_orig = datasets['train']
mnist_test_orig = datasets['test']

BUFFER_SIZE = 10000
BATCH_SIZE = 64
NUM_EPOCHS = 20
mnist_train = mnist_train_orig.map(lambda item: (tf.cast(
    item['image'], tf.float32) / 255.0, tf.cast(item['label'], tf.int32)))

mnist_test = mnist_test_orig.map(lambda item: (tf.cast(
    item['image'], tf.float32) / 255.0, tf.cast(item['label'], tf.int32)))

tf.random.set_seed(1)
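# reshuffle_each_iteration=False fixes the shuffle order, so take()/skip()
# below yield a disjoint validation/train split.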
mnist_train = mnist_train.shuffle(buffer_size=BUFFER_SIZE,
                                  reshuffle_each_iteration=False)
mnist_valid = mnist_train.take(10000).batch(BATCH_SIZE)
mnist_train = mnist_train.skip(10000).batch(BATCH_SIZE)

model = tf.keras.Sequential()

model.add(
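
# The original snippet is truncated at `model.add(`. A hypothetical
# continuation (not the original author's layers): a small CNN classifier
# for the 28x28x1 MNIST images prepared above.
#
# model.add(tf.keras.layers.Conv2D(32, 3, activation='relu',
#                                  input_shape=(28, 28, 1)))
# model.add(tf.keras.layers.MaxPooling2D())
# model.add(tf.keras.layers.Flatten())
# model.add(tf.keras.layers.Dense(10, activation='softmax'))
# model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
#               metrics=['accuracy'])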
Example #18
def train(pooling='global',
          backbonetype='mobilenetv2',
          output_stride=8,
          residual_shortcut=False,
          height_image=448,
          width_image=448,
          channels=3,
          crop_enable=False,
          height_crop=448,
          width_crop=448,
          debug_en=False,
          dataset_name='freiburg_forest',
          class_imbalance_correction=False,
          data_augmentation=True,
          multigpu_enable=True,
          batch_size=32,
          epochs=300,
          initial_epoch=-1,
          continue_training=False,
          fine_tune_last=False,
          base_learning_rate=0.007,
          learning_power=0.98,
          decay_steps=1,
          learning_rate_decay_step=300,
          decay=5**(-4),
          validation_enable=False):

    ## Check the number of labels
    if dataset_name == 'cityscape':
        classes = cityscape.classes
    elif dataset_name == 'citysmall':
        classes = citysmall.classes
    elif dataset_name == 'off_road_small':
        classes = off_road_small.classes
    elif dataset_name == 'freiburg_forest':
        classes = freiburg_forest.classes
    else:
        raise ValueError('Unknown dataset_name: {}'.format(dataset_name))

    n_classes = len(classes)
    ignore_label = classes[-1]['name'] == 'ignore'
    ignore_label = False  # Override: the ignore label is not used below.
    
    #n_classes = 12

    if crop_enable:
        # Plain asserts (the originals asserted a tuple, which is always true).
        assert height_crop == width_crop, (
            'When crop is enabled, height_crop should equal width_crop')
        assert height_crop <= height_image, (
            'When crop is enabled, height_crop should be less than or '
            'equal to height_image')
        assert width_crop <= width_image, (
            'When crop is enabled, width_crop should be less than or '
            'equal to width_image')
    else:
        height_crop = height_image
        width_crop = width_image
    
    # Construct a tf.data.Dataset
    info = tfds.builder(dataset_name).info
    print(info)
    # if validation_enable : #7%
    #     ds_train, ds_val, ds_test = tfds.load(name=dataset_name, split=['train[:-10%]', 'train[-10%:]', 'test'], as_supervised=True)
    # else:
    #     ds_train, ds_test = tfds.load(name=dataset_name, split=["train", 'test'], as_supervised=True)
        
    train, ds_val, ds_test = tfds.load(name=dataset_name, 
                                          split=[tfds.Split.TRAIN, 
                                                 tfds.Split.VALIDATION, 
                                                 tfds.Split.TEST], 
                                          as_supervised=True)

    height_ds = info.features['image'].shape[-3]
    width_ds = info.features['image'].shape[-2]
    # Add normalize
    def _normalize_img(image, label):
        image = tf.cast(image, tf.float32)/127.5 - 1
        if crop_enable:
            y1 = tf.random.uniform(shape=[], minval=0., maxval=(height_image-height_crop)/height_image, dtype=tf.float32)
            x1 = tf.random.uniform(shape=[], minval=0., maxval=(width_image-width_crop)/width_image, dtype=tf.float32)
    
            y2 = y1 + (height_crop/height_image)
            x2 = x1 + (width_crop/width_image)
    
            boxes = [[y1, x1, y2, x2]]
            image = tf.image.crop_and_resize([image], boxes, box_indices=[0], crop_size=(height_crop, width_crop), method=tf.image.ResizeMethod.BILINEAR)[0]
            label = tf.cast(tf.image.crop_and_resize([label], boxes, box_indices=[0], crop_size=(height_crop, width_crop), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)[0],dtype=tf.uint8)
        else:
            image = tf.image.resize(image, (height_image,width_image), method=tf.image.ResizeMethod.BILINEAR)
            label = tf.image.resize(label, (height_image,width_image), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return (image, label)
    
    # ds_train = ds_train.map(_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # if validation_enable :
    #     ds_val = ds_val.map(_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if validation_enable:
        ds_train = train
        train_size = info.splits['train'].num_examples
    else:
        ds_train = train.concatenate(ds_val)
        train_size = info.splits['train'].num_examples + info.splits['validation'].num_examples
    # ds_train = ds_train.take(100)
    ds_train = ds_train.map(_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds_val = ds_val.map(_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    

    ########################################################debug
    # ds_test = ds_test.map(_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # ds_test = ds_test.shuffle(124).batch(batch_size).prefetch(10)
    ########################################################debug
    
    if data_augmentation:
        # Add augmentations
        augmentations = [aug.flip, aug.color, aug.zoom, aug.rotate]
        
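        # Apply each augmentation independently with probability 0.25
        # (uniform > 0.75), then clip pixels back to [-1, 1].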
        for f in augmentations:
            ds_train = ds_train.map(lambda x, y: tf.cond(tf.random.uniform([], 0, 1) > 0.75, lambda: f(x, y), lambda: (x, y)),
                                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds_train = ds_train.map(lambda x, y: (tf.clip_by_value(x, -1, 1), y),  num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if debug_en:
            aug.plot_images(ds_train, n_images=8, samples_per_image=10, classes = classes)
    
    
    # Build your input pipeline
    ds_train = ds_train.repeat().shuffle(124).batch(batch_size).prefetch(10)
    ds_val = ds_val.shuffle(124).batch(batch_size).prefetch(10)
    # ds_train = ds_train.shuffle(124).batch(batch_size).prefetch(10)
    # if validation_enable :
    #     ds_val = ds_val.shuffle(124).batch(batch_size).prefetch(10)
    # else:
    #     ds_val = None
    
    # validation_steps=int(round(info.splits['test'].num_examples/batch_size))
    steps_per_epoch=int(round(train_size/batch_size))
    step_size = steps_per_epoch
    
    class MIoU(MeanIoU):
        """MeanIoU over per-class scores: argmax to label ids before updating."""

        def __init__(self, num_classes, name=None, dtype=None):
            super(MIoU, self).__init__(num_classes=num_classes, name=name,
                                       dtype=dtype)

        def update_state(self, y_true, y_pred, sample_weight=None):
            # Keras MeanIoU expects label ids; the model outputs per-class
            # scores, so reduce with argmax first.
            return super(MIoU, self).update_state(
                y_true=y_true,
                y_pred=tf.math.argmax(input=y_pred, axis=-1,
                                      output_type=tf.dtypes.int64),
                sample_weight=sample_weight)
    
    # class_imbalance_correction
    
    fold_name = (backbonetype+'s'+str(output_stride)+'_pooling_' + pooling
                 +('_residual_shortcut' if residual_shortcut else '')+'_ep'+str(epochs)
                 +('_crop_'+str(height_crop)+'x'+str(width_crop) if crop_enable else '_')
                 +('from' if crop_enable else '')+str(height_image)+'x'+str(width_image)
                 +'_'+('wda_' if data_augmentation else 'nda_')
                 +('wcic_' if class_imbalance_correction else 'ncic_')
                 +('wft_' if fine_tune_last else 'nft_')+dataset_name+'_b'
                 +str(batch_size)+('_n' if multigpu_enable else '_1')+'gpu')
    
    # multigpu_enable = False
    if multigpu_enable:
        strategy = tf.distribute.MirroredStrategy()
        with strategy.scope():
            cmsnet = CMSNet(dl_input_shape=(None, height_crop, width_crop, channels),
                            num_classes=n_classes, output_stride=output_stride,
                            pooling=pooling, residual_shortcut=residual_shortcut,
                            backbonetype=backbonetype)
            cmsnet.summary()
            #cmsnet.mySummary()
    
            optimizer = SGD(momentum=0.9, nesterov=True)
            #optimizer = RMSprop(lr=0.001, rho=0.9, epsilon=1e-6)
            #optimizer = Adadelta(lr=0.008, rho=0.95, epsilon=None, decay=0.0)
            # optimizer = Adamax(lr=0.002, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0)
            #optimizer = Nadam(lr=0.002, beta_1=0.9, beta_2=0.999, epsilon=None, schedule_decay=0.004)
            miou = MIoU(num_classes=n_classes)
    
            cmsnet.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer,
                          metrics=['accuracy', miou]
                          )
    else:
        cmsnet = CMSNet(dl_input_shape=(None, height_crop, width_crop, channels),
                        num_classes=n_classes, output_stride=output_stride, pooling=pooling,
                        residual_shortcut=residual_shortcut, backbonetype=backbonetype)
        cmsnet.summary()
        
        optimizer = SGD(momentum=0.9, nesterov=True)
        #optimizer = RMSprop(lr=0.001, rho=0.9, epsilon=1e-6)
        #optimizer = Adadelta(lr=0.008, rho=0.95, epsilon=None, decay=0.0)
        # optimizer = Adamax(lr=0.002, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0)
        #optimizer = Nadam(lr=0.002, beta_1=0.9, beta_2=0.999, epsilon=None, schedule_decay=0.004)
    
        miou = MIoU(num_classes=n_classes)
        cmsnet.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer,
                      metrics=['accuracy', miou])
    
    

    
    #fold_name = 's16_wda_ncic_nft_off_road_small_b8_ngpu_ep300_483x769_pooling_aspp_20190815-151551'
    
    # Define the Keras TensorBoard callback.
    if continue_traning:
        logdir = build_path + "logs/fit/" + fold_name  # continue
        if initial_epoch == -1:  # get last checkpoint epoch
            names = glob.glob(logdir+'*/weights.*')
            names.sort()
            initial_epoch = int(names[-1].split('.')[-4].split('-')[0])
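            # e.g. 'weights.010-0.35-0.72.hdf5': split('.')[-4] == '010-0',
            # and its '-'-split head '010' is the zero-padded epoch number.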

        #weights_path = glob.glob(logdir+'*/weights.*'+str(initial_epoch)+'-*')[0]
        weights_path = glob.glob(logdir+'*/weights.last*')[0]
        logdir = weights_path[:weights_path.find('/weights.')]
    
        print('Continuing training from ' + weights_path)
        if multigpu_enable:
            with strategy.scope():
                cmsnet.load_weights(weights_path)
        else:
            cmsnet.load_weights(weights_path)
    else:
        logdir = (build_path + "logs/fit/" + fold_name + '_'
                  + datetime.now().strftime("%Y%m%d-%H%M%S"))
        initial_epoch = 0
    
    
    
    
    # if validation_enable : #7%
    #     ckp = ModelCheckpoint(logdir+"/weights.{epoch:03d}-{val_loss:.2f}-{val_m_io_u:.2f}.hdf5",
    #                       monitor='val_m_io_u', mode='max',  verbose=0, save_best_only=True,
    #                       save_weights_only=True, period=1)
    # else:
    # ckp = ModelCheckpoint(logdir+"/weights.{epoch:03d}-{loss:.2f}-{m_io_u:.2f}.hdf5",
    #                   monitor='val_m_io_u', mode='max',  verbose=0, save_best_only=True,
    #                   save_weights_only=True, period=1)
    # ckp = ModelCheckpoint(logdir+"/weights.{epoch:03d}.hdf5",
    #                       mode='max',  verbose=0, save_best_only=True,
    #                       save_weights_only=True, period=1)
    ckp = ModelCheckpoint(logdir+"/weights.{epoch:03d}-{val_loss:.2f}-{val_m_io_u:.2f}.hdf5",
                          monitor='val_m_io_u', mode='max', verbose=0,
                          save_best_only=True, save_weights_only=True,
                          period=1)
    ckp_last = ModelCheckpoint(logdir+"/weights.last.hdf5", verbose=0,
                               save_best_only=False, save_weights_only=True,
                               period=1)
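    # Two checkpoint streams: `ckp` keeps only the best epoch as ranked by
    # validation mIoU, while `ckp_last` overwrites weights.last.hdf5 every
    # epoch so an interrupted run can resume from the most recent state.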
    
    tensorboard_callback = LRTensorBoard(log_dir=logdir, histogram_freq=0, 
                                         write_graph=True, write_images=True,
                                         update_freq='epoch', profile_batch=2, 
                                         embeddings_freq=0)
    
    # If applying class balancing:
    # from sklearn.utils import class_weight
    # class_weight = class_weight.compute_class_weight('balanced'
    #                                                ,np.unique(Y_train)
    #                                                ,Y_train)
    
    # class_weight = {cls:1 for cls in range(n_classes)}
    # #ignore the last label
    # if ignore_label:
    #     class_weight[n_classes-1] = 0
    
    if class_imbalance_correction:
        class_weight = np.ones(n_classes)  # TODO: insert imbalance equalizer
    elif ignore_label:
        class_weight = np.ones(n_classes)
    else:
        class_weight = None
    # ignore the last label
    if ignore_label:
        class_weight[-1] = 0
    
    # NOTE: hard override -- the class-weight logic above is effectively
    # disabled for this run.
    class_weight = None
    
    

    
    
    if fine_tune_last:
        base_learning_rate = 0.0001
        cmsnet.setFineTuning(lavel='fromSPPool')
        # learning = lambda epoch: polynomial_decay(epoch, initial_lrate = base_learning_rate,
        #     learning_rate_decay_step=learning_rate_decay_step, learning_power=learning_power, 
        #     end_learning_rate=0.0001, cycle=False)
        learning = lambda epoch: exponential_decay(
            learning_rate=base_learning_rate, global_step=epoch,
            decay_steps=decay_steps, decay_rate=learning_power)
    
        lrate = LearningRateScheduler(learning)
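        # NOTE: exponential_decay is defined elsewhere in this script; it is
        # assumed to follow the usual schedule (cf. the legacy
        # tf.compat.v1.train.exponential_decay):
        #
        #   lr(epoch) = learning_rate * decay_rate ** (epoch / decay_steps)
        #
        # so with decay_rate = learning_power < 1 the rate decays smoothly
        # from its base value.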
        hist1 = cmsnet.fit(ds_train, validation_data=ds_val, epochs=epochs,
                           callbacks=[tensorboard_callback, lrate, ckp],
                           initial_epoch=initial_epoch,
                           class_weight=class_weight)
    else:
    
        cmsnet.setFineTuning(lavel='fromAll')
        # learning = lambda epoch: polynomial_decay(epoch, initial_lrate = base_learning_rate,
        #         learning_rate_decay_step=learning_rate_decay_step, learning_power=learning_power, 
        #         end_learning_rate=0, cycle=False)
        learning = lambda epoch: exponential_decay(
            learning_rate=base_learning_rate, global_step=epoch,
            decay_steps=decay_steps, decay_rate=learning_power)
    
        lrate = LearningRateScheduler(learning)
        cmsnet.fit(ds_train, validation_data=ds_val,
                   epochs=round(epochs * steps_per_epoch / step_size),
                   steps_per_epoch=step_size,
                   callbacks=[tensorboard_callback, lrate, ckp, ckp_last],
                   initial_epoch=initial_epoch,
                   class_weight=class_weight,
                   use_multiprocessing=True)
        
        ########################################################debug
        # result = cmsnet.evaluate(ds_train, use_multiprocessing=True)
        # import json
        # with open(build_path+'result0.txt', 'w') as file:
        #     file.write(json.dumps(result))
            
        # classes_ids = [1, 2, 3, 4, 0]
        # from evaluation.trained import print_result2
        # print(result['name'])
        # print("Params: "+str(result['count_params']))
        # print_result2(result['classes'], np.array(result['confusion_matrix']), classes_ids)
        ########################################################debug
    
    # ########################################################debug
    # result = cmsnet.evaluate(ds_test, use_multiprocessing=True)
    # import json
    # with open(build_path+'result0.txt', 'w') as file:
    #     file.write(json.dumps(result))
    
    # classes_ids = [1, 2, 3, 4, 0]
    # from evaluation.trained import print_result2
    # print(result['name'])
    # print("Params: "+str(result['count_params']))
    # print_result2(result['classes'], np.array(result['confusion_matrix']), classes_ids)
    # ########################################################debug
    
    cmsnet.save_weights(logdir+"/weights.300-n.nn-n.nn.hdf5")
Example #19
import os
import time
import numpy as np
# import matplotlib as mpl
# import matplotlib.pyplot as plt
from pprint import pprint

import tensorflow as tf
import tensorflow_datasets as tfds
print(tf.__version__)

output_dir = "nmt"
en_vocab_file = os.path.join(output_dir, "en_vocab")
zh_vocab_file = os.path.join(output_dir, "zh_vocab")
checkpoint_path = os.path.join(output_dir, "checkpoints")
log_dir = os.path.join(output_dir, 'logs')
download_dir = "tensorflow-datasets/downloads"

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

tmp_builder = tfds.builder("wmt19_translate/zh-en")
pprint(tmp_builder.subsets)

config = tfds.translate.wmt.WmtConfig(
    version=tfds.core.Version('0.0.3',
                              experiments={tfds.core.Experiment.S3: False}),
    language_pair=("zh", "en"),
    subsets={tfds.Split.TRAIN: ["newscommentary_v14"]})
builder = tfds.builder("wmt_translate", config=config)
builder.download_and_prepare(download_dir=download_dir)
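# Once prepared, the corpus can be read back as a tf.data.Dataset; a minimal
# sketch (standard tfds API, supervised keys assumed to be (zh, en)):
#
#   examples = builder.as_dataset(split=tfds.Split.TRAIN, as_supervised=True)
#   for zh, en in examples.take(1):
#       print(zh.numpy().decode(), en.numpy().decode())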
Example #20
def main(argv):
    del argv  # unused arg
    tf.enable_v2_behavior()
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    train_input_fn = utils.load_input_fn(split=tfds.Split.TRAIN,
                                         name=FLAGS.dataset,
                                         batch_size=FLAGS.per_core_batch_size,
                                         use_bfloat16=FLAGS.use_bfloat16)
    clean_test_input_fn = utils.load_input_fn(
        split=tfds.Split.TEST,
        name=FLAGS.dataset,
        batch_size=FLAGS.per_core_batch_size,
        use_bfloat16=FLAGS.use_bfloat16)
    train_dataset = strategy.experimental_distribute_datasets_from_function(
        train_input_fn)

    test_datasets = {
        'clean':
        strategy.experimental_distribute_datasets_from_function(
            clean_test_input_fn),
    }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar10':
            load_c_input_fn = utils.load_cifar10_c_input_fn
        else:
            load_c_input_fn = functools.partial(utils.load_cifar100_c_input_fn,
                                                path=FLAGS.cifar100_c_path)
        corruption_types, max_intensity = utils.load_corrupted_test_info(
            FLAGS.dataset)
        for corruption in corruption_types:
            for intensity in range(1, max_intensity + 1):
                input_fn = load_c_input_fn(
                    corruption_name=corruption,
                    corruption_intensity=intensity,
                    batch_size=FLAGS.per_core_batch_size,
                    use_bfloat16=FLAGS.use_bfloat16)
                test_datasets['{0}_{1}'.format(corruption, intensity)] = (
                    strategy.experimental_distribute_datasets_from_function(
                        input_fn))

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    train_dataset_size = ds_info.splits['train'].num_examples
    steps_per_epoch = train_dataset_size // batch_size
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building ResNet model')
        model = wide_resnet(input_shape=ds_info.features['image'].shape,
                            depth=28,
                            width_multiplier=10,
                            num_classes=num_classes,
                            prior_stddev=FLAGS.prior_stddev,
                            dataset_size=train_dataset_size,
                            stddev_init=FLAGS.stddev_init)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(start_epoch * FLAGS.train_epochs) // 200
                           for start_epoch in FLAGS.lr_decay_epochs]
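        # FLAGS.lr_decay_epochs is specified against a 200-epoch reference
        # schedule; e.g. start_epoch=60 with train_epochs=200 decays at epoch
        # 60, and the decay points rescale proportionally for other
        # train_epochs settings.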
        lr_schedule = utils.LearningRateSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            ed.metrics.ExpectedCalibrationError(num_classes=num_classes,
                                                num_bins=FLAGS.num_bins),
            'train/kl':
            tf.keras.metrics.Mean(),
            'train/kl_scale':
            tf.keras.metrics.Mean(),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            ed.metrics.ExpectedCalibrationError(num_classes=num_classes,
                                                num_bins=FLAGS.num_bins),
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, max_intensity + 1):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        ed.metrics.ExpectedCalibrationError(
                            num_classes=num_classes, num_bins=FLAGS.num_bins))

        global_step = tf.Variable(
            0,
            trainable=False,
            name='global_step',
            dtype=tf.int64,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        checkpoint = tf.train.Checkpoint(model=model,
                                         optimizer=optimizer,
                                         global_step=global_step)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
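            # optimizer.iterations counts every applied batch, so integer
            # division by steps_per_epoch recovers the epoch to resume from.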

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))

                filtered_variables = []
                for var in model.trainable_variables:
                    # Apply l2 only on the BN parameters and bias terms. This
                    # excludes the fast-weight approximate posterior/prior
                    # parameters, but be careful with their naming scheme.
                    if 'batch_norm' in var.name or 'bias' in var.name:
                        filtered_variables.append(tf.reshape(var, (-1, )))

                l2_loss = FLAGS.l2 * 2 * tf.nn.l2_loss(
                    tf.concat(filtered_variables, axis=0))
                kl = sum(model.losses)
                kl_scale = tf.cast(global_step + 1, tf.float32)
                kl_scale /= steps_per_epoch * FLAGS.kl_annealing_epochs
                kl_scale = tf.minimum(1., kl_scale)
                kl_loss = kl_scale * kl
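                # Linear KL annealing: kl_scale ramps from ~0 to 1 over
                # FLAGS.kl_annealing_epochs epochs, phasing in the variational
                # KL penalty instead of letting it dominate early training.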

                # Scale the loss, since TPUStrategy sum-reduces gradients
                # across all replicas.
                loss = negative_log_likelihood + l2_loss + kl_loss
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/kl'].update_state(kl)
            metrics['train/kl_scale'].update_state(kl_scale)
            metrics['train/accuracy'].update_state(labels, logits)

            global_step.assign_add(1)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            # TODO(trandustin): Use more eval samples only on corrupted predictions;
            # it's expensive but a one-time compute if scheduled post-training.
            if FLAGS.num_eval_samples > 1 and dataset_name != 'clean':
                logits = tf.stack([
                    model(images, training=False)
                    for _ in range(FLAGS.num_eval_samples)
                ],
                                  axis=0)
            else:
                logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.softmax(logits)
            if FLAGS.num_eval_samples > 1 and dataset_name != 'clean':
                probs = tf.reduce_mean(probs, axis=0)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))

            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            train_step(train_iterator)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)

        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator, dataset_name)
            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types, max_intensity)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
Example #21
def main(_):
    if FLAGS.module_import:
        import_modules(FLAGS.module_import)

    if FLAGS.debug_start:
        pdb.set_trace()
    if FLAGS.sleep_start:
        time.sleep(60 * 60 * 3)

    if FLAGS.disable_tqdm:
        logging.info("Disabling tqdm.")
        tfds.disable_progress_bar()

    if FLAGS.checksums_dir:
        tfds.download.add_checksums_dir(FLAGS.checksums_dir)

    datasets_to_build = set(FLAGS.datasets and FLAGS.datasets.split(",")
                            or tfds.list_builders())
    datasets_to_build -= set(FLAGS.exclude_datasets.split(","))
    version = "experimental_latest" if FLAGS.experimental_latest_version else None
    logging.info("Running download_and_prepare for datasets:\n%s",
                 "\n".join(datasets_to_build))
    logging.info('Version: "%s"', version)
    builders = {
        name: tfds.builder(name, data_dir=FLAGS.data_dir, version=version)
        for name in datasets_to_build
    }

    if FLAGS.builder_config_id is not None:
        # Requesting a single config of a single dataset
        if len(builders) > 1:
            raise ValueError(
                "--builder_config_id can only be used when building a single dataset"
            )
        builder = builders[list(builders.keys())[0]]
        if not builder.BUILDER_CONFIGS:
            raise ValueError(
                "--builder_config_id can only be used with datasets with configs"
            )
        config = builder.BUILDER_CONFIGS[FLAGS.builder_config_id]
        logging.info("Running download_and_prepare for config: %s",
                     config.name)
        builder_for_config = tfds.builder(builder.name,
                                          data_dir=FLAGS.data_dir,
                                          config=config,
                                          version=version)
        download_and_prepare(builder_for_config)
    else:
        for name, builder in builders.items():
            if builder.BUILDER_CONFIGS and "/" not in name:
                # If builder has multiple configs, and no particular config was
                # requested, then compute all.
                for config in builder.BUILDER_CONFIGS:
                    builder_for_config = tfds.builder(builder.name,
                                                      data_dir=FLAGS.data_dir,
                                                      config=config,
                                                      version=version)
                    download_and_prepare(builder_for_config)
            else:
                # If there is a slash in the name, then user requested a specific
                # dataset configuration.
                download_and_prepare(builder)
Example #22
def main(argv):
    del argv  # unused arg
    tf.enable_v2_behavior()
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)

    def train_input_fn(ctx):
        """Sets up local (per-core) dataset batching."""
        dataset = utils.load_distributed_dataset(
            split=tfds.Split.TRAIN,
            name=FLAGS.dataset,
            batch_size=FLAGS.per_core_batch_size // FLAGS.num_models,
            drop_remainder=True,
            use_bfloat16=FLAGS.use_bfloat16,
            proportion=FLAGS.train_proportion)
        if ctx and ctx.num_input_pipelines > 1:
            dataset = dataset.shard(ctx.num_input_pipelines,
                                    ctx.input_pipeline_id)
        return dataset

    # Regardless of the training proportion, we still evaluate the model on
    # the full test dataset.
    def test_input_fn(ctx):
        """Sets up local (per-core) dataset batching."""
        dataset = utils.load_distributed_dataset(
            split=tfds.Split.TEST,
            name=FLAGS.dataset,
            batch_size=FLAGS.per_core_batch_size // FLAGS.num_models,
            drop_remainder=True,
            use_bfloat16=FLAGS.use_bfloat16)
        if ctx and ctx.num_input_pipelines > 1:
            dataset = dataset.shard(ctx.num_input_pipelines,
                                    ctx.input_pipeline_id)
        return dataset

    train_dataset = strategy.experimental_distribute_datasets_from_function(
        train_input_fn)
    test_dataset = strategy.experimental_distribute_datasets_from_function(
        test_input_fn)
    ds_info = tfds.builder(FLAGS.dataset).info

    batch_size = ((FLAGS.per_core_batch_size // FLAGS.num_models) *
                  FLAGS.num_cores)
    # train_proportion is a float, so steps_per_epoch must be cast to int.
    steps_per_epoch = int(
        (ds_info.splits['train'].num_examples * FLAGS.train_proportion) //
        batch_size)
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size

    if FLAGS.use_bfloat16:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
        tf.keras.mixed_precision.experimental.set_policy(policy)

    with strategy.scope():
        logging.info('Building Keras ResNet-32 model')
        model = batchensemble_model.ensemble_resnet_v1(
            input_shape=ds_info.features['image'].shape,
            depth=32,
            num_classes=ds_info.features['label'].num_classes,
            width_multiplier=4,
            num_models=FLAGS.num_models,
            random_sign_init=FLAGS.random_sign_init,
            dropout_rate=FLAGS.dropout_rate,
            l2=FLAGS.l2)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_schedule = utils.ResnetLearningRateSchedule(steps_per_epoch,
                                                       base_lr, _LR_SCHEDULE)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        train_nll = tf.keras.metrics.Mean('train_nll', dtype=tf.float32)
        train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'train_accuracy', dtype=tf.float32)
        test_nll = tf.keras.metrics.Mean('test_nll', dtype=tf.float32)
        test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)
        test_nlls = []
        test_accs = []
        for i in range(FLAGS.num_models):
            test_nlls.append(
                tf.keras.metrics.Mean('test_nll_{}'.format(i),
                                      dtype=tf.float32))
            test_accs.append(
                tf.keras.metrics.SparseCategoricalAccuracy(
                    'test_accuracy_{}'.format(i), dtype=tf.float32))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries/'))

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            if FLAGS.version2:
                images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
                labels = tf.tile(labels, [FLAGS.num_models])
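                # version2 ("shared batch") mode: tile the inputs so each of
                # the num_models ensemble members sees the full batch; the
                # model routes the i-th tile through member i's fast weights.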

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.use_bfloat16:
                    logits = tf.cast(logits, tf.float32)
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, since TPUStrategy sum-reduces gradients
                # across all replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)

            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = []
                for grad, var in zip(grads, model.trainable_variables):
                    # Apply a different learning rate to the fast-weight
                    # approximate posterior/prior parameters. This excludes BN
                    # and slow weights, but be careful with the naming scheme.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grads_and_vars.append(
                            (grad * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grads_and_vars.append((grad, var))
                optimizer.apply_gradients(grads_and_vars)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            train_loss.update_state(loss)
            train_nll.update_state(negative_log_likelihood)
            train_accuracy.update_state(labels, logits)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images, labels = inputs
            images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
            logits = model(images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            probs = tf.nn.softmax(logits)
            per_probs = tf.split(probs,
                                 num_or_size_splits=FLAGS.num_models,
                                 axis=0)
            for i in range(FLAGS.num_models):
                member_probs = per_probs[i]
                member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                test_nlls[i].update_state(member_loss)
                test_accs[i].update_state(labels, member_probs)

            probs = tf.reduce_mean(per_probs, axis=0)
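            # Ensemble prediction: average member probabilities (not logits)
            # and score NLL/accuracy on the averaged distribution.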

            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
            test_nll.update_state(negative_log_likelihood)
            test_accuracy.update_state(labels, probs)

        strategy.experimental_run_v2(step_fn, args=(next(iterator), ))

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        with summary_writer.as_default():
            for step in range(steps_per_epoch):
                train_step(train_iterator)

                current_step = epoch * steps_per_epoch + (step + 1)
                max_steps = steps_per_epoch * FLAGS.train_epochs
                time_elapsed = time.time() - start_time
                steps_per_sec = float(current_step) / time_elapsed
                eta_seconds = (max_steps - current_step) / steps_per_sec
                message = (
                    '{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                    'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                        current_step / max_steps, epoch + 1,
                        FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                        time_elapsed / 60))
                if step % 20 == 0:
                    logging.info(message)

            tf.summary.scalar('train/loss',
                              train_loss.result(),
                              step=epoch + 1)
            tf.summary.scalar('train/negative_log_likelihood',
                              train_nll.result(),
                              step=epoch + 1)
            tf.summary.scalar('train/accuracy',
                              train_accuracy.result(),
                              step=epoch + 1)
            logging.info('Train Loss: %s, Accuracy: %s%%',
                         round(float(train_loss.result()), 4),
                         round(float(train_accuracy.result() * 100), 2))

            train_loss.reset_states()
            train_nll.reset_states()
            train_accuracy.reset_states()

            test_iterator = iter(test_dataset)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_step(test_iterator)
            tf.summary.scalar('test/negative_log_likelihood',
                              test_nll.result(),
                              step=epoch + 1)
            tf.summary.scalar('test/accuracy',
                              test_accuracy.result(),
                              step=epoch + 1)
            logging.info('Test NLL: %s, Accuracy: %s%%',
                         round(float(test_nll.result()), 4),
                         round(float(test_accuracy.result() * 100), 2))

            test_nll.reset_states()
            test_accuracy.reset_states()

            for i in range(FLAGS.num_models):
                tf.summary.scalar('test/ensemble_nll_member{}'.format(i),
                                  test_nlls[i].result(),
                                  step=epoch + 1)
                tf.summary.scalar('test/ensemble_accuracy_member{}'.format(i),
                                  test_accs[i].result(),
                                  step=epoch + 1)
                logging.info('Member %d Test loss: %s, accuracy: %s%%', i,
                             round(float(test_nlls[i].result()), 4),
                             round(float(test_accs[i].result() * 100), 2))
                test_nlls[i].reset_states()
                test_accs[i].reset_states()

        if (epoch + 1) % 20 == 0:
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)
Example #23
def load_dataset(split,
                 batch_size,
                 name,
                 use_bfloat16,
                 normalize=True,
                 drop_remainder=True,
                 repeat=False,
                 proportion=1.0,
                 data_dir=None):
  """Loads CIFAR dataset for training or testing.

  Args:
    split: tfds.Split.
    batch_size: The global batch size to use.
    name: A string indicating whether it is cifar10 or cifar100.
    use_bfloat16: data type, bfloat16 precision or float32.
    normalize: Whether to apply mean-std normalization on features.
    drop_remainder: bool.
    repeat: bool.
    proportion: float, the proportion of dataset to be used.
    data_dir: Directory where the dataset is stored to be loaded via tfds.load.
      Optional, useful for loading datasets stored on GCS.

  Returns:
    Input function which returns a locally-sharded dataset batch.
  """
  if use_bfloat16:
    dtype = tf.bfloat16
  else:
    dtype = tf.float32
  ds_info = tfds.builder(name).info
  image_shape = ds_info.features['image'].shape
  dataset_size = ds_info.splits['train'].num_examples

  def preprocess(image, label):
    """Image preprocessing function."""
    if split == tfds.Split.TRAIN:
      image = tf.image.resize_with_crop_or_pad(
          image, image_shape[0] + 4, image_shape[1] + 4)
      image = tf.image.random_crop(image, image_shape)
      image = tf.image.random_flip_left_right(image)

    image = tf.image.convert_image_dtype(image, dtype)
    if normalize:
      mean = tf.constant([0.4914, 0.4822, 0.4465], dtype=dtype)
      std = tf.constant([0.2023, 0.1994, 0.2010], dtype=dtype)
      image = (image - mean) / std
    label = tf.cast(label, dtype)
    return image, label

  if proportion == 1.0:
    dataset = tfds.load(
        name, split=split, data_dir=data_dir, as_supervised=True)
  else:
    new_name = '{}:3.*.*'.format(name)
    if split == tfds.Split.TRAIN:
      # Use round instead of floor to avoid floating-point artifacts, e.g.
      # proportion = 1 - 0.8 = 0.19999999... would otherwise floor to 19%.
      new_split = 'train[:{}%]'.format(round(100 * proportion))
    elif split == tfds.Split.VALIDATION:
      new_split = 'train[-{}%:]'.format(round(100 * proportion))
    elif split == tfds.Split.TEST:
      new_split = 'test[:{}%]'.format(round(100 * proportion))
    else:
      raise ValueError('Provide valid split.')
    dataset = tfds.load(new_name, split=new_split, as_supervised=True)
  if split == tfds.Split.TRAIN or repeat:
    dataset = dataset.shuffle(buffer_size=dataset_size).repeat()

  dataset = dataset.map(preprocess,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset
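# A minimal usage sketch (argument values hypothetical):
#
#   train_ds = load_dataset(split=tfds.Split.TRAIN, batch_size=128,
#                           name='cifar10', use_bfloat16=False)
#   images, labels = next(iter(train_ds))  # images: [128, 32, 32, 3]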
Example #24
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and Tensorboard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit number of processes
      and devices, or config is underspecified.
  """
    n_processes = jax.process_count()  # Number of processes
    n_devices = jax.local_device_count()  # Number of local devices per process

    if config.train_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Training batch size must be divisible by the total number of devices, "
            "but training batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.train_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    if config.eval_batch_size % (n_processes * n_devices) > 0:
        raise ValueError(
            "Eval batch size must be divisible by the total number of devices, "
            "but eval batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.eval_batch_size, n_processes * n_devices, n_processes,
             n_devices))

    per_process_train_batch_size = config.train_batch_size // n_processes
    per_process_eval_batch_size = config.eval_batch_size // n_processes

    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))
    else:
        train_summary_writer = None
        eval_summary_writer = None

    rng = random.PRNGKey(config.seed)
    rng, init_rng = random.split(rng)

    ds_info = tfds.builder(config.dataset_name).info
    num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

    num_train_steps = int(num_train_examples * config.num_train_epochs //
                          config.train_batch_size)
    num_warmup_steps = int(config.warmup_proportion * num_train_steps)
    # Round up evaluation frequency to a multiple of 10.
    eval_frequency = int(
        math.ceil(config.eval_proportion * num_train_steps / 10)) * 10
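    # e.g. eval_proportion=0.05 with num_train_steps=12345 gives
    # ceil(617.25 / 10) * 10 = 620, i.e. an eval every 620 steps.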

    is_regression_task = config.dataset_name == "glue/stsb"

    num_classes = (1 if is_regression_task else
                   ds_info.features["label"].num_classes)

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()

    frozen_config = ml_collections.FrozenConfigDict(config)
    model = models.SequenceClassificationModel(config=frozen_config,
                                               n_classes=num_classes)

    params = _init_params(model, init_rng, config)

    optimizer = _create_adam_optimizer(config.learning_rate, params)

    # In case current job restarts, ensure that we continue from where we left
    # off.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    start_step = int(optimizer.state.step)

    # Otherwise, try to restore optimizer and model state from config checkpoint.
    if (start_step == 0 and "init_checkpoint_dir" in config
            and config.init_checkpoint_dir):
        optimizer = _restore_pretrained_model(optimizer, params, config)

    # We access model state only from optimizer via optimizer.target.
    del params

    optimizer = jax_utils.replicate(optimizer)

    if is_regression_task:
        compute_stats = functools.partial(_compute_regression_stats,
                                          model=model,
                                          pad_id=tokenizer.pad_id())
    else:
        compute_stats = functools.partial(_compute_classification_stats,
                                          model=model,
                                          pad_id=tokenizer.pad_id())

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=num_warmup_steps,
        decay_steps=num_train_steps - num_warmup_steps,
    )

    glue_inputs = functools.partial(input_pipeline.glue_inputs,
                                    dataset_name=config.dataset_name,
                                    max_seq_length=config.max_seq_length,
                                    tokenizer=tokenizer)
    train_ds = glue_inputs(split=tfds.Split.TRAIN,
                           batch_size=per_process_train_batch_size,
                           training=True)
    train_iter = iter(train_ds)

    if config.dataset_name == "glue/mnli":
        # MNLI contains two validation and test datasets.
        split_suffixes = ["_matched", "_mismatched"]
    else:
        split_suffixes = [""]

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    rngs = random.split(rng, n_devices)

    loss_and_metrics_fn = functools.partial(_compute_loss_and_metrics,
                                            model=model,
                                            pad_id=tokenizer.pad_id())
    p_train_step = jax.pmap(functools.partial(
        train_utils.train_step,
        loss_and_metrics_fn=loss_and_metrics_fn,
        learning_rate_fn=learning_rate_fn),
                            axis_name="batch")
    p_eval_step = jax.pmap(functools.partial(train_utils.eval_step,
                                             metric_fn=compute_stats),
                           axis_name="batch")
    eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name,
                                              is_regression_task)

    train_metrics = []

    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = next(train_iter)
            train_batch = common_utils.shard(train_batch)

            optimizer, train_step_metrics, rngs = p_train_step(optimizer,
                                                               train_batch,
                                                               rng=rngs)
            train_metrics.append(train_step_metrics)

        if ((step > 0 and config.save_checkpoints_steps
             and step % config.save_checkpoints_steps == 0)
                or step == num_train_steps - 1) and jax.process_index() == 0:
            # Save un-replicated optimizer and model state.
            checkpoints.save_checkpoint(workdir,
                                        jax_utils.unreplicate(optimizer),
                                        step,
                                        keep=2)

        # Periodic metric handling.
        if step % eval_frequency != 0 and step < num_train_steps - 1:
            continue

        logging.info("Gathering training metrics at step: %d", step)

        train_metrics = common_utils.get_metrics(train_metrics)
        train_summary = {
            "loss":
            jnp.sum(train_metrics["loss"]) /
            jnp.sum(train_metrics["num_labels"]),
            "learning_rate":
            learning_rate_fn(step)
        }
        if not is_regression_task:
            train_summary["accuracy"] = jnp.sum(
                train_metrics["correct_predictions"]) / jnp.sum(
                    train_metrics["num_labels"])

        if jax.process_index() == 0:
            assert train_summary_writer
            for key, val in train_summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        train_metrics = []

        logging.info("Gathering validation metrics at step: %d", step)

        for split_suffix in split_suffixes:
            eval_ds = glue_inputs(split=tfds.Split.VALIDATION + split_suffix,
                                  batch_size=per_process_eval_batch_size,
                                  training=False)

            all_stats = []
            for _, eval_batch in zip(range(config.max_num_eval_steps),
                                     eval_ds):
                all_stats.append(
                    _evaluate(p_eval_step, optimizer.target, eval_batch,
                              n_devices))
            flat_stats = {}
            # All batches of output stats are the same size.
            for k in all_stats[0]:
                flat_stats[k] = np.concatenate([stat[k] for stat in all_stats],
                                               axis=0)
            eval_summary = eval_metrics_fn(flat_stats)

            if jax.process_index() == 0:
                assert eval_summary_writer
                for key, val in eval_summary.items():
                    eval_summary_writer.scalar(f"{key}{split_suffix}", val,
                                               step)
                eval_summary_writer.flush()
Example #25
def get_dataset_from_tf(name: str):
    builder = tfds.builder(name=name)
    builder.download_and_prepare()
    ds = builder.as_dataset(shuffle_files=False)
    ds_numpy = tfds.as_numpy(dataset=ds)
    return ds_numpy
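# A minimal usage sketch (dataset name hypothetical):
#
#   ds_numpy = get_dataset_from_tf("mnist")
#   first = next(iter(ds_numpy["train"]))
#   image, label = first["image"], first["label"]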
Example #26
def main(argv):
    del argv  # unused arg
    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.total_batch_size // FLAGS.ensemble_size
    # train_proportion is a float, so steps_per_epoch must be cast to int.
    steps_per_epoch = int(
        (ds_info.splits['train'].num_examples * FLAGS.train_proportion) //
        batch_size)
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    data_dir = FLAGS.data_dir
    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    train_builder = ub.datasets.get(FLAGS.dataset,
                                    data_dir=data_dir,
                                    download_data=FLAGS.download_data,
                                    split=tfds.Split.TRAIN,
                                    validation_percent=1. -
                                    FLAGS.train_proportion)
    train_dataset = train_builder.load(batch_size=batch_size)
    validation_dataset = None
    steps_per_validation = 0
    if FLAGS.train_proportion < 1.0:
        validation_builder = ub.datasets.get(
            FLAGS.dataset,
            data_dir=data_dir,
            download_data=FLAGS.download_data,
            split=tfds.Split.VALIDATION,
            validation_percent=1. - FLAGS.train_proportion,
            drop_remainder=FLAGS.drop_remainder_for_eval)
        validation_dataset = validation_builder.load(batch_size=batch_size)
        validation_dataset = strategy.experimental_distribute_dataset(
            validation_dataset)
        steps_per_validation = validation_builder.num_examples // batch_size
    clean_test_builder = ub.datasets.get(
        FLAGS.dataset,
        data_dir=data_dir,
        download_data=FLAGS.download_data,
        split=tfds.Split.TEST,
        drop_remainder=FLAGS.drop_remainder_for_eval)
    clean_test_dataset = clean_test_builder.load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    steps_per_epoch = train_builder.num_examples // batch_size
    steps_per_eval = clean_test_builder.num_examples // batch_size
    num_classes = 100 if FLAGS.dataset == 'cifar100' else 10

    if FLAGS.eval_on_ood:
        ood_dataset_names = FLAGS.ood_dataset
        ood_ds, steps_per_ood = ood_utils.load_ood_datasets(
            ood_dataset_names,
            clean_test_builder,
            1 - FLAGS.train_proportion,
            batch_size,
            drop_remainder=FLAGS.drop_remainder_for_eval)
        ood_datasets = {
            name: strategy.experimental_distribute_dataset(ds)
            for name, ds in ood_ds.items()
        }

    if FLAGS.corruptions_interval > 0:
        extra_kwargs = {}
        if FLAGS.dataset == 'cifar100':
            data_dir = FLAGS.cifar100_c_path
        corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
        for corruption_type in corruption_types:
            for severity in range(1, 6):
                dataset = ub.datasets.get(
                    f'{FLAGS.dataset}_corrupted',
                    corruption_type=corruption_type,
                    data_dir=data_dir,
                    severity=severity,
                    split=tfds.Split.TEST,
                    drop_remainder=FLAGS.drop_remainder_for_eval,
                    **extra_kwargs).load(batch_size=batch_size)
                test_datasets[f'{corruption_type}_{severity}'] = (
                    strategy.experimental_distribute_dataset(dataset))

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building Keras model')
        model = ub.models.wide_resnet_sngp_be(
            input_shape=(32, 32, 3),
            batch_size=batch_size,
            depth=28,
            width_multiplier=10,
            num_classes=num_classes,
            ensemble_size=FLAGS.ensemble_size,
            random_sign_init=FLAGS.random_sign_init,
            l2=FLAGS.l2,
            use_gp_layer=FLAGS.use_gp_layer,
            gp_input_dim=FLAGS.gp_input_dim,
            gp_hidden_dim=FLAGS.gp_hidden_dim,
            gp_scale=FLAGS.gp_scale,
            gp_bias=FLAGS.gp_bias,
            gp_input_normalization=FLAGS.gp_input_normalization,
            gp_cov_discount_factor=FLAGS.gp_cov_discount_factor,
            gp_cov_ridge_penalty=FLAGS.gp_cov_ridge_penalty,
            use_spec_norm=FLAGS.use_spec_norm,
            spec_norm_iteration=FLAGS.spec_norm_iteration,
            spec_norm_bound=FLAGS.spec_norm_bound)

        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=1.0 -
                                            FLAGS.one_minus_momentum,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'train/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss':
            tf.keras.metrics.Mean(),
            'train/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood':
            tf.keras.metrics.Mean(),
            'test/accuracy':
            tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece':
            rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/stddev':
            tf.keras.metrics.Mean(),
        }
        eval_dataset_splits = ['test']
        if validation_dataset:
            metrics.update({
                'validation/negative_log_likelihood':
                tf.keras.metrics.Mean(),
                'validation/accuracy':
                tf.keras.metrics.SparseCategoricalAccuracy(),
                'validation/ece':
                rm.metrics.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
                'validation/stddev':
                tf.keras.metrics.Mean(),
            })
            eval_dataset_splits += ['validation']
        if FLAGS.eval_on_ood:
            ood_metrics = ood_utils.create_ood_metrics(
                ood_dataset_names, tpr_list=FLAGS.ood_tpr_threshold)
            metrics.update(ood_metrics)
        for i in range(FLAGS.ensemble_size):
            for dataset_split in eval_dataset_splits:
                metrics[f'{dataset_split}/nll_member_{i}'] = (
                    tf.keras.metrics.Mean())
                metrics[f'{dataset_split}/accuracy_member_{i}'] = (
                    tf.keras.metrics.SparseCategoricalAccuracy())
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, 6):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        rm.metrics.ExpectedCalibrationError(
                            num_bins=FLAGS.num_bins))
                    corrupt_metrics['test/stddev_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
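            # Recover the epoch to resume from via the optimizer's global
            # step, which was restored as part of the checkpoint.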
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch
        if FLAGS.saved_model_dir:
            logging.info('Saved model dir : %s', FLAGS.saved_model_dir)
            latest_checkpoint = tf.train.latest_checkpoint(
                FLAGS.saved_model_dir)
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
        if FLAGS.eval_only:
            initial_epoch = FLAGS.train_epochs - 1  # Run just one epoch of eval

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
            labels = tf.tile(labels, [FLAGS.ensemble_size])
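            # BatchEnsemble folds the ensemble dimension into the batch: the
            # inputs are tiled ensemble_size times so each member processes
            # the full batch through its own fast weights.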

            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if isinstance(logits, (list, tuple)):
                    # If model returns a tuple of (logits, covmat), extract logits
                    logits, _ = logits
                negative_log_likelihood = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits, from_logits=True))
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, since the distribution strategy will
                # sum-reduce the gradients across replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            # Separate learning rate implementation.
            if FLAGS.fast_weight_lr_multiplier != 1.0:
                grads_and_vars = []
                for grad, var in zip(grads, model.trainable_variables):
                    # Apply a different learning rate to the fast-weight
                    # approximate posterior/prior parameters. This excludes BN
                    # and slow weights, but be careful with the naming scheme.
                    if ('batch_norm' not in var.name
                            and 'kernel' not in var.name):
                        grads_and_vars.append(
                            (grad * FLAGS.fast_weight_lr_multiplier, var))
                    else:
                        grads_and_vars.append((grad, var))
                optimizer.apply_gradients(grads_and_vars)
            else:
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].add_batch(probs, label=labels)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

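        # Using tf.range keeps this loop in-graph: AutoGraph lowers it to a
        # tf.while_loop, so a single train_step call runs a whole epoch.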
        for _ in tf.range(tf.cast(steps_per_epoch, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_split, dataset_name, num_steps):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']

            logits_list = []
            stddev_list = []

            for i in range(FLAGS.ensemble_size):
                logits = model(images, training=False)
                if isinstance(logits, (list, tuple)):
                    # If model returns a tuple of (logits, covmat), extract both
                    logits, covmat = logits
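                    # Mean-field approximation to the GP posterior predictive;
                    # mean_field_logits typically rescales the logits by
                    # 1 / sqrt(1 + mean_field_factor * diag(covmat)).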
                    logits = mean_field_logits(
                        logits,
                        covmat,
                        mean_field_factor=FLAGS.gp_mean_field_factor)
                else:
                    covmat = tf.eye(logits.shape[0])

                stddev = tf.sqrt(tf.linalg.diag_part(covmat))

                stddev_list.append(stddev)
                logits_list.append(logits)

                member_probs = tf.nn.softmax(logits)
                member_loss = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                metrics[f'{dataset_split}/nll_member_{i}'].update_state(
                    member_loss)
                metrics[f'{dataset_split}/accuracy_member_{i}'].update_state(
                    labels, member_probs)
            # Logits dimension is (ensemble_size, batch_size, num_classes).
            logits_list = tf.stack(logits_list, axis=0)
            stddev_list = tf.stack(stddev_list, axis=0)

            stddev = tf.reduce_mean(stddev_list, axis=0)
            probs_list = tf.nn.softmax(logits_list)
            probs = tf.reduce_mean(probs_list, axis=0)

            labels_broadcasted = tf.broadcast_to(
                labels,
                [FLAGS.ensemble_size, tf.shape(labels)[0]])
            log_likelihoods = -tf.keras.losses.sparse_categorical_crossentropy(
                labels_broadcasted, logits_list, from_logits=True)
            negative_log_likelihood = tf.reduce_mean(
                -tf.reduce_logsumexp(log_likelihoods, axis=[0]) +
                tf.math.log(float(FLAGS.ensemble_size)))
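            # NLL of the ensemble mixture: log p(y|x) =
            # log((1/M) * sum_m p_m(y|x)) = logsumexp_m(log p_m(y|x)) - log(M),
            # hence the logsumexp over members minus log(ensemble_size).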

            if dataset_name == 'clean':
                metrics[
                    f'{dataset_split}/negative_log_likelihood'].update_state(
                        negative_log_likelihood)
                metrics[f'{dataset_split}/accuracy'].update_state(
                    labels, probs)
                metrics[f'{dataset_split}/ece'].add_batch(probs, label=labels)
                metrics[f'{dataset_split}/stddev'].update_state(stddev)
            elif dataset_name.startswith('ood'):
                ood_labels = 1 - inputs['is_in_distribution']
                if FLAGS.dempster_shafer_ood:
                    ood_scores = ood_utils.DempsterShaferUncertainty(logits)
                else:
                    ood_scores = 1 - tf.reduce_max(probs, axis=-1)
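                # 1 - max softmax probability is the standard MSP OOD score;
                # the Dempster-Shafer alternative is commonly computed as
                # K / (K + sum_k exp(logit_k)), as in the SNGP paper.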

                # Edge case: dataset_name may itself contain underscores.
                ood_dataset_name = '_'.join(dataset_name.split('_')[1:])
                for name, metric in metrics.items():
                    if ood_dataset_name in name:
                        metric.update_state(ood_labels, ood_scores)
            elif FLAGS.corruptions_interval > 0:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(dataset_name)].add_batch(
                    probs, label=labels)
                corrupt_metrics['test/stddev_{}'.format(
                    dataset_name)].update_state(stddev)

        for _ in tf.range(tf.cast(num_steps, tf.int32)):
            strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        train_step(train_iterator)

        current_step = (epoch + 1) * steps_per_epoch
        max_steps = steps_per_epoch * FLAGS.train_epochs
        time_elapsed = time.time() - start_time
        steps_per_sec = float(current_step) / time_elapsed
        eta_seconds = (max_steps - current_step) / steps_per_sec
        message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                   'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                       current_step / max_steps, epoch + 1, FLAGS.train_epochs,
                       steps_per_sec, eta_seconds / 60, time_elapsed / 60))
        logging.info(message)

        if validation_dataset:
            validation_iterator = iter(validation_dataset)
            test_step(validation_iterator, 'validation', 'clean',
                      steps_per_validation)
        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            logging.info('Starting to run eval at epoch: %s', epoch)
            test_start_time = time.time()
            test_step(test_iterator, 'test', dataset_name, steps_per_eval)
            ms_per_example = (time.time() - test_start_time) * 1e6 / batch_size
            metrics['test/ms_per_example'].update_state(ms_per_example)

            logging.info('Done with testing on %s', dataset_name)

        if FLAGS.eval_on_ood:
            for dataset_name in ood_dataset_names:
                ood_iterator = iter(
                    ood_datasets['ood_{}'.format(dataset_name)])
                logging.info('Calculating OOD on dataset %s', dataset_name)
                logging.info('Running OOD eval at epoch: %s', epoch)
                test_step(ood_iterator, 'test', 'ood_{}'.format(dataset_name),
                          steps_per_ood[dataset_name])

                logging.info('Done with OOD eval on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        for i in range(FLAGS.ensemble_size):
            logging.info(
                'Member %d Test Loss: %.4f, Accuracy: %.2f%%', i,
                metrics['test/nll_member_{}'.format(i)].result(),
                metrics['test/accuracy_member_{}'.format(i)].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        # Metrics from Robustness Metrics (like ECE) will return a dict with a
        # single key/value, instead of a scalar.
        total_results = {
            k: (list(v.values())[0] if isinstance(v, dict) else v)
            for k, v in total_results.items()
        }
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if FLAGS.corruptions_interval > 0:
            for metric in corrupt_metrics.values():
                metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
    with summary_writer.as_default():
        hp.hparams({
            'base_learning_rate': FLAGS.base_learning_rate,
            'one_minus_momentum': FLAGS.one_minus_momentum,
            'l2': FLAGS.l2,
            'gp_mean_field_factor': FLAGS.gp_mean_field_factor,
            'gp_input_dim': FLAGS.gp_input_dim,
            'gp_scale': FLAGS.gp_scale,
            'gp_hidden_dim': FLAGS.gp_hidden_dim,
            'fast_weight_lr_multiplier': FLAGS.fast_weight_lr_multiplier,
            'random_sign_init': FLAGS.random_sign_init,
        })
Example #27
def main(argv):
    fmt = '[%(filename)s:%(lineno)s] %(message)s'
    formatter = logging.PythonFormatter(fmt)
    logging.get_absl_handler().setFormatter(formatter)
    del argv  # unused arg

    tf.io.gfile.makedirs(FLAGS.output_dir)
    logging.info('Saving checkpoints at %s', FLAGS.output_dir)
    tf.random.set_seed(FLAGS.seed)

    data_dir = None
    if FLAGS.use_gpu:
        logging.info('Use GPU')
        strategy = tf.distribute.MirroredStrategy()
    else:
        logging.info('Use TPU at %s',
                     FLAGS.tpu if FLAGS.tpu is not None else 'local')
        data_dir = FLAGS.data_dir
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)

    ds_info = tfds.builder(FLAGS.dataset).info
    batch_size = FLAGS.per_core_batch_size * FLAGS.num_cores
    train_dataset_size = (ds_info.splits['train'].num_examples *
                          FLAGS.train_proportion)
    steps_per_epoch = int(train_dataset_size / batch_size)
    logging.info('Steps per epoch %s', steps_per_epoch)
    logging.info('Size of the dataset %s',
                 ds_info.splits['train'].num_examples)
    logging.info('Train proportion %s', FLAGS.train_proportion)
    steps_per_eval = ds_info.splits['test'].num_examples // batch_size
    num_classes = ds_info.features['label'].num_classes

    train_dataset = ub.datasets.get(
        FLAGS.dataset,
        split=tfds.Split.TRAIN,
        download_data=True,
        validation_percent=1. - FLAGS.train_proportion,
        data_dir=data_dir).load(batch_size=batch_size)
    clean_test_dataset = ub.datasets.get(
        FLAGS.dataset, split=tfds.Split.TEST,
        data_dir=data_dir).load(batch_size=batch_size)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_datasets = {
        'clean': strategy.experimental_distribute_dataset(clean_test_dataset),
    }
    if FLAGS.corruptions_interval > 0:
        if FLAGS.dataset == 'cifar100':
            data_dir = FLAGS.cifar100_c_path
        corruption_types, _ = utils.load_corrupted_test_info(FLAGS.dataset)
        for corruption_type in corruption_types:
            for severity in range(1, 6):
                dataset = ub.datasets.get(
                    f'{FLAGS.dataset}_corrupted',
                    corruption_type=corruption_type,
                    severity=severity,
                    split=tfds.Split.TEST,
                    data_dir=data_dir).load(batch_size=batch_size)
                test_datasets[f'{corruption_type}_{severity}'] = (
                    strategy.experimental_distribute_dataset(dataset))

    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.output_dir, 'summaries'))

    with strategy.scope():
        logging.info('Building ResNet model')
        model = ub.models.wide_resnet(
            input_shape=ds_info.features['image'].shape,
            depth=28,
            width_multiplier=10,
            num_classes=num_classes,
            l2=FLAGS.l2,
            hps=_extract_hyperparameter_dictionary(),
            version=2)
        logging.info('Model input shape: %s', model.input_shape)
        logging.info('Model output shape: %s', model.output_shape)
        logging.info('Model number of weights: %s', model.count_params())
        # Linearly scale learning rate and the decay epochs by vanilla settings.
        base_lr = FLAGS.base_learning_rate * batch_size / 128
        lr_decay_epochs = [(int(start_epoch_str) * FLAGS.train_epochs) // 200
                           for start_epoch_str in FLAGS.lr_decay_epochs]
        lr_schedule = ub.schedules.WarmUpPiecewiseConstantSchedule(
            steps_per_epoch,
            base_lr,
            decay_ratio=FLAGS.lr_decay_ratio,
            decay_epochs=lr_decay_epochs,
            warmup_epochs=FLAGS.lr_warmup_epochs)
        optimizer = tf.keras.optimizers.SGD(lr_schedule,
                                            momentum=0.9,
                                            nesterov=True)
        metrics = {
            'train/negative_log_likelihood': tf.keras.metrics.Mean(),
            'train/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
            'train/loss': tf.keras.metrics.Mean(),
            'train/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
            'test/negative_log_likelihood': tf.keras.metrics.Mean(),
            'test/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
            'test/ece': um.ExpectedCalibrationError(num_bins=FLAGS.num_bins),
        }
        if FLAGS.corruptions_interval > 0:
            corrupt_metrics = {}
            for intensity in range(1, 6):
                for corruption in corruption_types:
                    dataset_name = '{0}_{1}'.format(corruption, intensity)
                    corrupt_metrics['test/nll_{}'.format(dataset_name)] = (
                        tf.keras.metrics.Mean())
                    corrupt_metrics['test/accuracy_{}'.format(
                        dataset_name)] = (
                            tf.keras.metrics.SparseCategoricalAccuracy())
                    corrupt_metrics['test/ece_{}'.format(dataset_name)] = (
                        um.ExpectedCalibrationError(num_bins=FLAGS.num_bins))

        checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
        latest_checkpoint = tf.train.latest_checkpoint(FLAGS.output_dir)
        initial_epoch = 0
        if latest_checkpoint:
            # checkpoint.restore must be within a strategy.scope() so that optimizer
            # slot variables are mirrored.
            checkpoint.restore(latest_checkpoint)
            logging.info('Loaded checkpoint %s', latest_checkpoint)
            initial_epoch = optimizer.iterations.numpy() // steps_per_epoch

    @tf.function
    def train_step(iterator):
        """Training StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            with tf.GradientTape() as tape:
                logits = model(images, training=True)
                if FLAGS.label_smoothing == 0.:
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, logits, from_logits=True))
                else:
                    one_hot_labels = tf.one_hot(tf.cast(labels, tf.int32),
                                                num_classes)
                    negative_log_likelihood = tf.reduce_mean(
                        tf.keras.losses.categorical_crossentropy(
                            one_hot_labels,
                            logits,
                            from_logits=True,
                            label_smoothing=FLAGS.label_smoothing))
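                    # label_smoothing=eps replaces the one-hot target with
                    # (1 - eps) * one_hot + eps / num_classes inside
                    # categorical_crossentropy.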
                l2_loss = sum(model.losses)
                loss = negative_log_likelihood + l2_loss
                # Scale the loss, since the distribution strategy will
                # sum-reduce the gradients across replicas.
                scaled_loss = loss / strategy.num_replicas_in_sync

            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            probs = tf.nn.softmax(logits)
            metrics['train/ece'].update_state(labels, probs)
            metrics['train/loss'].update_state(loss)
            metrics['train/negative_log_likelihood'].update_state(
                negative_log_likelihood)
            metrics['train/accuracy'].update_state(labels, logits)

        strategy.run(step_fn, args=(next(iterator), ))

    @tf.function
    def test_step(iterator, dataset_name):
        """Evaluation StepFn."""
        def step_fn(inputs):
            """Per-Replica StepFn."""
            images = inputs['features']
            labels = inputs['labels']
            logits = model(images, training=False)
            probs = tf.nn.softmax(logits)
            negative_log_likelihood = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(labels, probs))
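            # Unlike the train step, `probs` are already softmaxed here, so
            # the loss uses the default from_logits=False.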

            if dataset_name == 'clean':
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, probs)
                metrics['test/ece'].update_state(labels, probs)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, probs)

        strategy.run(step_fn, args=(next(iterator), ))

    metrics.update({'test/ms_per_example': tf.keras.metrics.Mean()})
    metrics.update({'train/ms_per_example': tf.keras.metrics.Mean()})

    train_iterator = iter(train_dataset)
    start_time = time.time()
    tb_callback = None
    if FLAGS.collect_profile:
        tb_callback = tf.keras.callbacks.TensorBoard(profile_batch=(100, 102),
                                                     log_dir=os.path.join(
                                                         FLAGS.output_dir,
                                                         'logs'))
        tb_callback.set_model(model)
    for epoch in range(initial_epoch, FLAGS.train_epochs):
        logging.info('Starting to run epoch: %s', epoch)
        for step in range(steps_per_epoch):
            if tb_callback:
                tb_callback.on_train_batch_begin(step)
            train_start_time = time.time()
            train_step(train_iterator)
            ms_per_example = (time.time() -
                              train_start_time) * 1e6 / batch_size
            metrics['train/ms_per_example'].update_state(ms_per_example)

            current_step = epoch * steps_per_epoch + (step + 1)
            max_steps = steps_per_epoch * FLAGS.train_epochs
            time_elapsed = time.time() - start_time
            steps_per_sec = float(current_step) / time_elapsed
            eta_seconds = (max_steps - current_step) / steps_per_sec
            message = ('{:.1%} completion: epoch {:d}/{:d}. {:.1f} steps/s. '
                       'ETA: {:.0f} min. Time elapsed: {:.0f} min'.format(
                           current_step / max_steps, epoch + 1,
                           FLAGS.train_epochs, steps_per_sec, eta_seconds / 60,
                           time_elapsed / 60))
            if step % 20 == 0:
                logging.info(message)
            if tb_callback:
                tb_callback.on_train_batch_end(step)
        datasets_to_evaluate = {'clean': test_datasets['clean']}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            datasets_to_evaluate = test_datasets
        for dataset_name, test_dataset in datasets_to_evaluate.items():
            test_iterator = iter(test_dataset)
            logging.info('Testing on dataset %s', dataset_name)
            for step in range(steps_per_eval):
                if step % 20 == 0:
                    logging.info('Starting to run eval step %s of epoch: %s',
                                 step, epoch)
                test_start_time = time.time()
                test_step(test_iterator, dataset_name)
                ms_per_example = (time.time() -
                                  test_start_time) * 1e6 / batch_size
                metrics['test/ms_per_example'].update_state(ms_per_example)

            logging.info('Done with testing on %s', dataset_name)

        corrupt_results = {}
        if (FLAGS.corruptions_interval > 0
                and (epoch + 1) % FLAGS.corruptions_interval == 0):
            corrupt_results = utils.aggregate_corrupt_metrics(
                corrupt_metrics, corruption_types)

        logging.info('Train Loss: %.4f, Accuracy: %.2f%%',
                     metrics['train/loss'].result(),
                     metrics['train/accuracy'].result() * 100)
        logging.info('Test NLL: %.4f, Accuracy: %.2f%%',
                     metrics['test/negative_log_likelihood'].result(),
                     metrics['test/accuracy'].result() * 100)
        total_results = {
            name: metric.result()
            for name, metric in metrics.items()
        }
        total_results.update(corrupt_results)
        with summary_writer.as_default():
            for name, result in total_results.items():
                tf.summary.scalar(name, result, step=epoch + 1)

        for metric in metrics.values():
            metric.reset_states()

        if (FLAGS.checkpoint_interval > 0
                and (epoch + 1) % FLAGS.checkpoint_interval == 0):
            checkpoint_name = checkpoint.save(
                os.path.join(FLAGS.output_dir, 'checkpoint'))
            logging.info('Saved checkpoint to %s', checkpoint_name)

    final_checkpoint_name = checkpoint.save(
        os.path.join(FLAGS.output_dir, 'checkpoint'))
    logging.info('Saved last checkpoint to %s', final_checkpoint_name)
Example #28
def main(_):
    print(len(tfds.builder(FLAGS.dataset).BUILDER_CONFIGS))
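    # Prints how many BuilderConfigs (named dataset variants) the TFDS
    # builder for FLAGS.dataset defines.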
Example #29
def document_single_builder(builder):
    """Doc string for a single builder, with or without configs."""
    mod_name = builder.__class__.__module__
    cls_name = builder.__class__.__name__
    mod_file = sys.modules[mod_name].__file__
    if mod_file.endswith("pyc"):
        mod_file = mod_file[:-1]

    description_prefix = ""

    if builder.builder_configs:
        # Dataset with configs; document each one
        config_docs = []
        for config in builder.BUILDER_CONFIGS:
            builder = tfds.builder(builder.name, config=config)
            info = builder.info
            # TODO(rsepassi): document the actual config object
            config_doc = SINGLE_CONFIG_ENTRY.format(
                builder_name=builder.name,
                config_name=config.name,
                description=config.description,
                version=config.version,
                feature_information=make_feature_information(info),
                size=tfds.units.size_str(info.size_in_bytes),
            )
            config_docs.append(config_doc)
        out_str = DATASET_WITH_CONFIGS_ENTRY.format(
            snakecase_name=builder.name,
            module_and_class="%s.%s" % (tfds_mod_name(mod_name), cls_name),
            cls_url=cls_url(mod_name),
            config_names="\n".join([
                CONFIG_BULLET.format(
                    name=config.name,
                    description=config.description,
                    version=config.version,
                    size=tfds.units.size_str(
                        tfds.builder(builder.name,
                                     config=config).info.size_in_bytes))
                for config in builder.BUILDER_CONFIGS
            ]),
            config_cls="%s.%s" %
            (tfds_mod_name(mod_name), type(builder.builder_config).__name__),
            configs="\n".join(config_docs),
            urls=format_urls(info.urls),
            url=url_from_info(info),
            supervised_keys=str(info.supervised_keys),
            citation=make_citation(info.citation),
            statistics_information=make_statistics_information(info),
            description=builder.info.description,
            description_prefix=description_prefix,
        )
    else:
        info = builder.info
        out_str = DATASET_ENTRY.format(
            snakecase_name=builder.name,
            module_and_class="%s.%s" % (tfds_mod_name(mod_name), cls_name),
            cls_url=cls_url(mod_name),
            description=info.description,
            description_prefix=description_prefix,
            version=info.version,
            feature_information=make_feature_information(info),
            statistics_information=make_statistics_information(info),
            urls=format_urls(info.urls),
            url=url_from_info(info),
            supervised_keys=str(info.supervised_keys),
            citation=make_citation(info.citation),
            size=tfds.units.size_str(info.size_in_bytes),
        )

    out_str = schema_org(builder) + "\n" + out_str
    return out_str
Example #30
    def __init__(self,
                 split: str,
                 validation_percent: float = 0.0,
                 shuffle_buffer_size: Optional[int] = 16384,
                 num_parallel_parser_calls: int = 64,
                 try_gcs: bool = False,
                 download_data: bool = False,
                 is_training: Optional[bool] = None,
                 preprocessing_type: str = 'resnet',
                 use_bfloat16: bool = False,
                 normalize_input: bool = False,
                 image_size: int = 224,
                 resnet_preprocessing_resize_method: Optional[str] = None,
                 ensemble_size: int = 1,
                 one_hot: bool = False,
                 mixup_params: Optional[Dict[str, Any]] = None,
                 run_mixup: bool = False,
                 **unused_kwargs: Dict[str, Any]):
        """Create an ImageNet tf.data.Dataset builder.

    Args:
      split: a dataset split, either a custom tfds.Split or one of the
        tfds.Split enums [TRAIN, VALIDATION, TEST] or their lowercase string
        names.
      validation_percent: the percent of the training set to use as a validation
        set.
      shuffle_buffer_size: the number of examples to use in the shuffle buffer
        for tf.data.Dataset.shuffle().
      num_parallel_parser_calls: the number of parallel threads to use while
        preprocessing in tf.data.Dataset.map().
      try_gcs: Whether or not to try to use the GCS stored versions of dataset
        files.
      download_data: Whether or not to download data before loading.
      is_training: Whether or not the given `split` is the training split. Only
        required when the passed split is not one of ['train', 'validation',
        'test', tfds.Split.TRAIN, tfds.Split.VALIDATION, tfds.Split.TEST].
      preprocessing_type: Which type of preprocessing to apply, either
        'inception' or 'resnet'.
      use_bfloat16: Whether or not to use bfloat16 or float32 images.
      normalize_input: Whether or not to normalize images by the ImageNet mean
        and stddev.
      image_size: The size of the image in pixels.
      resnet_preprocessing_resize_method: Optional string for the resize method
        to use for resnet preprocessing.
      ensemble_size: `int` for number of ensemble members used in Mixup.
      one_hot: whether or not to use one-hot labels.
      mixup_params: hparams of mixup.
      run_mixup: An explicit flag for whether or not to run mixup when
        `mixup_params['mixup_alpha'] > 0`. By default, mixup is only run in
        training mode when `mixup_params['mixup_alpha'] > 0`.
      **unused_kwargs: Ignored.
    """
        name = 'imagenet2012'
        dataset_builder = tfds.builder(name, try_gcs=try_gcs)
        if is_training is None:
            is_training = split in ['train', tfds.Split.TRAIN]
        new_split = base.get_validation_percent_split(
            dataset_builder,
            validation_percent,
            split,
            test_split=tfds.Split.VALIDATION)
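        # TFDS imagenet2012 has no labeled test split, so the official
        # 'validation' split serves as the test set here, and a validation
        # set is carved out of 'train' via validation_percent.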
        if preprocessing_type == 'inception':
            decoders = {
                'image': tfds.decode.SkipDecoding(),
            }
        else:
            decoders = None
        super(ImageNetDataset, self).__init__(
            name=name,
            dataset_builder=dataset_builder,
            split=new_split,
            is_training=is_training,
            shuffle_buffer_size=shuffle_buffer_size,
            num_parallel_parser_calls=num_parallel_parser_calls,
            fingerprint_key='file_name',
            download_data=download_data,
            decoders=decoders)
        self._preprocessing_type = preprocessing_type
        self._use_bfloat16 = use_bfloat16
        self._normalize_input = normalize_input
        self._image_size = image_size
        self._resnet_preprocessing_resize_method = resnet_preprocessing_resize_method
        self._run_mixup = run_mixup

        self.ensemble_size = ensemble_size
        self._one_hot = one_hot
        if mixup_params is None:
            mixup_params = {}
        self._mixup_params = mixup_params
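        # When mixup is enabled (mixup_params['mixup_alpha'] > 0), mixup
        # typically draws lambda ~ Beta(alpha, alpha) and mixes pairs as
        # x_tilde = lambda * x_i + (1 - lambda) * x_j (and likewise for the
        # one-hot labels), which is why one-hot labels are usually required.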