def input_shape():
    """
    Get the shape of a single batch of model inputs.

    Returns
    -------
    tuple
        (replica batch size, FLAGS.model_tilesize, FLAGS.model_tilesize, band count).
    """
    batch = csf.distribution.replica_batch_size()
    tile = FLAGS.model_tilesize
    bands = gf.n_bands()
    return (batch, tile, tile, bands)
# Example no. 2
# 0
def data_shape():
    """
    Get the shape of a single batch of input imagery.

    Returns
    -------
    tuple
        (replica batch size, FLAGS.data_tilesize, FLAGS.data_tilesize, band count).
    """
    batch = csf.distribution.replica_batch_size()
    tile = FLAGS.data_tilesize
    bands = gf.n_bands()
    return (batch, tile, tile, bands)
def _create_view(scene, dropout_rate, seed=None):
    """
    Apply augmentation to a set of input imagery, creating a new view.
    Note that this function is autograph-traced and takes a Python integer input (seed),
    so keep the number of calls with distinct seeds to a minimum.
    Do not pass Python values to any other argument.

    The augmentations are applied in this order, each gated by its flag: random crop
    (only when the model and data tile sizes differ), random brightness, random
    contrast, band dropout, random flips, and a random multiple-of-90-degree rotation.

    Parameters
    ----------
    scene : tf.Tensor
        A tensor of aligned input imagery.
    dropout_rate : None or tf.Tensor
        A scalar, float Tensor holding the current dropout rate.
        Included as an argument to work well with scheduling and autograph.
        If None, band dropout is skipped.
    seed : int, optional
        Random seed to use. Used to ensure that views get different random numbers.
        Defaults to FLAGS.random_seed when None.

    Returns
    -------
    tf.Tensor
        A view of the input imagery with crop, jitter, band dropout, flips, and
        rotation applied (each subject to its flag).
    """
    # Compare against None explicitly: `seed or ...` would silently discard a
    # caller-supplied seed of 0.
    if seed is None:
        seed = FLAGS.random_seed

    # NOTE(review): every random op below shares the same op-level seed. Inside a
    # traced function, ops with identical seeds may draw identical sequences, which
    # would correlate e.g. the two flips -- confirm whether distinct per-op seeds
    # are wanted.
    if FLAGS.model_tilesize != FLAGS.data_tilesize:
        scene = tf.image.random_crop(scene, input_shape(), name="crop", seed=seed)
    if FLAGS.random_brightness_delta:
        scene = tf.image.random_brightness(
            scene, FLAGS.random_brightness_delta, seed=seed
        )
    if FLAGS.random_contrast_delta:
        scene = tf.image.random_contrast(
            scene,
            1.0 - FLAGS.random_contrast_delta,
            1.0 + FLAGS.random_contrast_delta,
            seed=seed,
        )

    # The documented contract allows `dropout_rate=None`; tf.nn.dropout cannot take a
    # None rate, so skip the op entirely in that case.
    if dropout_rate is not None:
        # The noise shape broadcasts over both spatial axes, so whole bands are
        # dropped per example rather than individual pixels.
        scene = tf.nn.dropout(
            scene,
            dropout_rate,
            noise_shape=(csf.distribution.replica_batch_size(), 1, 1, gf.n_bands()),
            name="band_dropout",
            seed=seed,
        )

    if FLAGS.flips:
        scene = tf.image.random_flip_up_down(scene, seed=seed)
        scene = tf.image.random_flip_left_right(scene, seed=seed)

    if FLAGS.rotation:
        # `maxval` is exclusive for integer dtypes, so 4 is required to sample all of
        # k = 0, 1, 2, 3 (a maxval of 3 never produces the 270-degree rotation).
        quarter_turns = tf.random.uniform(
            shape=(), minval=0, maxval=4, dtype=tf.dtypes.int32, seed=seed
        )
        scene = tf.image.rot90(scene, k=quarter_turns)

    return scene
# Example no. 4
# 0
def load_osm_dataset(tfrecord_glob, band_indices):
    """
    Load the OpenStreetMap classification dataset.

    Each record is parsed into an image holding only the requested bands, rescaled
    from uint8 via ``x / 128 - 1``, and a one-hot label over ``OSM_CLASSES``.

    Parameters
    ----------
    tfrecord_glob : string
        A glob specifying the path of files used to load the dataset.
    band_indices : [int]
        List of input bands to keep.

    Returns
    -------
    tf.data.Dataset
        The dataset, yielding (image, label) pairs.
    """
    n_classes = len(OSM_CLASSES)
    stored_image_shape = (OSM_TILESIZE, OSM_TILESIZE, gf.n_bands())
    label_shape = (n_classes, )

    feature_spec = {
        "spot_naip_phr":
        tf.io.FixedLenSequenceFeature([], dtype=tf.string, allow_missing=True),
        "label":
        tf.io.FixedLenSequenceFeature([], dtype=tf.int64, allow_missing=True),
    }

    def _decode_example(serialized):
        parsed = tf.io.parse_single_example(serialized, feature_spec)

        # Records store all bands; decode them, then keep only the requested ones
        # in the order given by `band_indices`.
        pixels = tf.io.decode_raw(parsed["spot_naip_phr"], tf.uint8)
        full_image = tf.reshape(pixels, stored_image_shape)
        kept_bands = [
            tf.expand_dims(full_image[..., band], axis=-1)
            for band in band_indices
        ]
        image = tf.concat(kept_bands, axis=-1)

        one_hot = tf.one_hot(parsed["label"], depth=n_classes)
        target = tf.reshape(one_hot, label_shape)

        image = (tf.cast(image, tf.dtypes.float32) / 128.0) - 1.0
        return image, target

    options = tf.data.Options()
    options.experimental_deterministic = False
    options.experimental_optimization.parallel_batch = True
    options.experimental_optimization.map_vectorization.enabled = True

    files = tf.data.Dataset.list_files(tfrecord_glob)
    records = files.interleave(tf.data.TFRecordDataset).with_options(options)
    return records.map(_decode_example).cache()
def encoder_head(
    size,
    bands=None,
    batch_size=8,
    input_scaling=1.0,
    checkpoint=None,
    trainable=True,
    assert_checkpoint=False,
):
    """
    Build a ResNet encoder from a subset of available bands, re-ordering the bands
    correctly and filling in any missing bands with zeroes.

    Parameters
    ----------
    size : int
        Tilesize this encoder accepts.
    bands : [string], optional
        List of bands this encoder uses. All other bands are filled in with zeroes.
        May appear in any order, but must be a subset of FLAGS.bands. If None,
        use all available bands. Bands that match nothing are silently ignored.
    batch_size : int, optional
        Batch size for the encoder.
    input_scaling : float, optional
        Number to multiply inputs by.
        When using CSF weights, should be 1 / (1 - final dropout rate used in training).
    checkpoint : str, optional
         - If "imagenet", the encoder is initialized with ImageNet weights and must
           have exactly 3 bands.
         - If "random" then initialize randomly.
         - If a path to a checkpoint file (locally or in Google cloud), initialize
           from that checkpoint.
         - If a path to a directory containing checkpoint files (locally or in Google
           cloud), initialize from the latest checkpoint in that directory.
         - If None, no checkpoint path is looked up and nothing is restored; the
           encoder keeps whatever `resnet_encoder` builds by default.
    trainable : bool, optional
        If False, the encoder is frozen.
    assert_checkpoint : bool, optional
        If True, require that a checkpoint is loaded.

    Returns
    -------
    tf.Tensor
        The overall model inputs.
    tf.Tensor
        The encoder inputs.
    tf.Tensor
        The encoder outputs.

    Raises
    ------
    ValueError
        If no bands are provided, or if ImageNet weights are requested with a number
        of bands other than three.
    """

    if bands is None:
        bands = FLAGS.bands
    n_bands = len(bands)

    if n_bands <= 0:
        raise ValueError("You must provide some bands.")

    # Load upstream weights into encoder
    if checkpoint == "imagenet":
        if n_bands != 3:
            raise ValueError(
                "If initializing an encoder with ImageNet weights, you "
                "must provide exactly three bands."
            )
        n_input_bands = 3
        encoder = resnet_encoder(3, weights="imagenet")
    elif checkpoint == "random":
        n_input_bands = gf.n_bands()
        encoder = resnet_encoder(n_input_bands, weights=None)
    else:
        n_input_bands = gf.n_bands()
        encoder = resnet_encoder(n_input_bands)

        ckpt = tf.train.Checkpoint(encoder=encoder)
        if checkpoint is None:
            # `tf.train.latest_checkpoint` cannot take None (it joins the directory
            # with a filename), so skip the lookup when no checkpoint was requested.
            ckpt_path = None
        else:
            # A directory resolves to its latest checkpoint; otherwise treat the
            # value as a direct checkpoint path.
            ckpt_path = tf.train.latest_checkpoint(checkpoint) or checkpoint
        ckpt_restore_status = ckpt.restore(ckpt_path)

        if assert_checkpoint:
            ckpt_restore_status.assert_nontrivial_match().expect_partial()
        else:
            ckpt_restore_status.expect_partial()

    encoder.trainable = trainable
    model_inputs = Input(batch_shape=(batch_size, size, size, n_bands))

    def _band_or_zeros(is_match):
        # Return the first provided band satisfying `is_match` as a single-channel
        # tensor, or an all-zero channel when no provided band matches.
        for band_i, band in enumerate(bands):
            if is_match(band):
                return K.expand_dims(model_inputs[..., band_i], axis=-1)
        return K.zeros(shape=(batch_size, size, size, 1))

    # Reorganize the present inputs according to the order given
    if n_input_bands == 3:  # Order RGB bands correctly for ImageNet experiments
        # NOTE(review): this branch is also taken for non-ImageNet checkpoints when
        # gf.n_bands() happens to be 3 -- confirm that is intended.
        to_concat = [
            _band_or_zeros(lambda band, suffix=rgb_band: band.endswith(suffix))
            for rgb_band in ("red", "green", "blue")
        ]
    else:
        to_concat = [
            _band_or_zeros(lambda band, wanted=default_band: band == wanted)
            for default_band in FLAGS.bands
        ]
    all_inputs = Concatenate(axis=-1)(to_concat)

    # Multiply inputs according to missing bands
    scaled_inputs = Lambda(lambda x: x * input_scaling)(all_inputs)
    encoded = encoder(scaled_inputs)

    return model_inputs, scaled_inputs, encoded
def run_unsupervised_training():
    """
    Perform a full unsupervised training run programmatically.

    Builds the distributed dataset, encoder, optimizer, checkpoint manager and
    metrics, then loops: update the learning-rate and band-dropout schedules, run
    FLAGS.callback_frequency training steps, write summaries, and periodically save
    a checkpoint. Each training step encodes two augmented views of the same batch
    and minimizes the weighted sum of per-layer contrastive losses.

    Training resumes from FLAGS.initial_checkpoint, or else from the latest
    checkpoint in FLAGS.out_dir, when either exists. The loop stops once the
    optimizer has taken at least FLAGS.max_batches steps (it runs forever if that
    flag is falsy); a final checkpoint is always saved on exit.
    """
    logging.info(
        "Starting unsupervised training run with flags:\n{}".format(
            FLAGS.flags_into_string()
        )
    )
    csf.distribution.initialize()
    data_shape = csf.data.data_shape()
    layers_and_weights = layer_loss_weights().items()

    logging.debug("Building dataset.")
    with csf.distribution.tpu_worker_context():
        dataset = csf.distribution.distribute_dataset_fn(csf.data.load_dataset)
        dataset_iterator = iter(dataset)

    with csf.distribution.tpu_worker_context(), csf.distribution.distributed_context():
        tf.random.set_seed(FLAGS.random_seed)

        logging.debug("Creating schedules.")
        # Schedules live in non-trainable variables so they can be reassigned from
        # Python between calls to the traced training function.
        learning_rate = tf.Variable(
            0.0,
            trainable=False,
            name="learning_rate",
            dtype=tf.dtypes.float32,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        dropout_rate = tf.Variable(
            0.0,
            trainable=False,
            name="dropout_rate_rate",
            dtype=tf.dtypes.float32,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

        logging.debug("Building model and optimizer.")
        encoder = resnet_encoder(gf.n_bands())
        encoder_vars = encoder.trainable_variables
        optimizer = tf.optimizers.Adam(learning_rate, clipnorm=FLAGS.gradient_clipnorm)

        logging.debug("Building checkpoint objects.")
        ckpt = tf.train.Checkpoint(encoder=encoder, optimizer=optimizer)
        ckpt_manager = tf.train.CheckpointManager(
            ckpt,
            FLAGS.out_dir,
            FLAGS.max_checkpoints,
            FLAGS.keep_checkpoint_every_n_hours,
        )
        # An explicitly requested initial checkpoint takes priority over resuming
        # from the output directory.
        restore_ckpt = FLAGS.initial_checkpoint or ckpt_manager.latest_checkpoint
        if restore_ckpt:
            logging.info("Continuing training from checkpoint: {}".format(restore_ckpt))
            ckpt.restore(restore_ckpt)
        else:
            logging.info("Initializing encoder and optimizer from scratch.")

        logging.debug("Building metrics.")
        summary_writer = tf.summary.create_file_writer(FLAGS.out_dir)
        loss_metric = tf.metrics.Mean("loss", tf.dtypes.float32)
        layer_loss_metrics = {
            layer: tf.metrics.Mean("{}_loss".format(layer), tf.dtypes.float32)
            for layer in layer_loss_weights().keys()
        }
        layer_accuracy_metrics = {
            layer: tf.metrics.Mean("{}_accuracy".format(layer), tf.dtypes.float32)
            for layer in layer_loss_weights().keys()
        }
        layer_rep_scale_metrics = {
            layer: tf.metrics.Mean(
                "{}_representation_size".format(layer), tf.dtypes.float32
            )
            for layer in layer_loss_weights().keys()
        }
        all_metrics = (
            [loss_metric]
            + list(layer_loss_metrics.values())
            + list(layer_accuracy_metrics.values())
            + list(layer_rep_scale_metrics.values())
        )

        def write_metrics(step):
            # Write the current schedule values and every accumulated metric at
            # `step`, then reset the metrics for the next reporting window.
            with summary_writer.as_default():
                tf.summary.scalar("learning_rate", learning_rate, step)
                tf.summary.scalar("band_dropout_rate", dropout_rate, step)

                for metric in all_metrics:
                    tf.summary.scalar(metric.name, metric.result(), step)
                    metric.reset_states()

            summary_writer.flush()

        logging.debug("Building distributed execution functions.")

        @csf.distribution.distribute_computation
        def _replicated_training_step(batch):
            # Reshape the incoming batch to the image shape used by the views.
            batch = tf.reshape(batch, data_shape)
            with tf.name_scope("training_step"):
                # Two views of the same batch, distinguished only by their seeds.
                with tf.name_scope("view_1"):
                    view_1 = _create_view(batch, dropout_rate, seed=1)
                with tf.name_scope("view_2"):
                    view_2 = _create_view(batch, dropout_rate, seed=2)

                losses = []

                with tf.GradientTape() as tape:
                    representations_1 = encoder(view_1)
                    representations_2 = encoder(view_2)

                    for layer, weight in layers_and_weights:
                        with tf.name_scope("layer_{}".format(layer)):
                            rep_1 = representations_1[layer]
                            rep_2 = representations_2[layer]
                            loss, accuracy = _contrastive_loss(rep_1, rep_2)

                            # Plot the average 2-norm of representations
                            with tf.name_scope("compute_scale"):
                                rep_flat = tf.reshape(
                                    rep_1, (csf.distribution.replica_batch_size(), -1)
                                )
                                rep_norms = tf.norm(rep_flat, axis=-1)
                                layer_rep_scale_metrics[layer].update_state(rep_norms)

                            losses.append(weight * loss)
                            layer_loss_metrics[layer].update_state(loss)
                            layer_accuracy_metrics[layer].update_state(accuracy)

                    loss_total = tf.reduce_sum(losses, name="loss_total")

                gradients = tape.gradient(loss_total, encoder_vars)
                optimizer.apply_gradients(zip(gradients, encoder_vars))
                loss_metric.update_state(loss_total)

        @tf.function
        def train_steps(iter_, steps):
            # Run `steps` training steps inside a single traced call.
            for _ in tf.range(steps):
                _replicated_training_step((next(iter_),))

        logging.info("Beginning unsupervised training.")
        while True:
            step = optimizer.iterations.numpy()

            # Check the limit before training rather than after: `step` is read
            # before `train_steps` runs, so testing it afterwards compares a stale
            # count and always trains one extra round of `callback_frequency`
            # batches. Checking here also stops immediately when resuming from a
            # checkpoint that already reached the limit.
            if FLAGS.max_batches and step >= FLAGS.max_batches:
                break

            logging.info("Starting step: {}".format(step))

            learning_rate.assign(
                csf.utils.optional_warmup(
                    step, FLAGS.learning_rate, FLAGS.learning_rate_warmup_batches
                )
            )
            dropout_rate.assign(
                csf.utils.optional_warmup(
                    step,
                    FLAGS.band_dropout_rate,
                    FLAGS.band_dropout_rate_warmup_batches,
                )
            )
            train_steps(
                dataset_iterator,
                tf.convert_to_tensor(FLAGS.callback_frequency, dtype=tf.int32),
            )
            # Metrics are recorded against the step count from the start of this
            # round.
            write_metrics(step)
            if (step // FLAGS.callback_frequency) % FLAGS.checkpoint_frequency == 0:
                ckpt_path = ckpt_manager.save()
                logging.info("Saved checkpoint at path: {}.".format(ckpt_path))

        ckpt_path = ckpt_manager.save()
        logging.info("Done with unsupervised training.")
        logging.info("Saving final checkpoint at path: {}.".format(ckpt_path))
# Example no. 7
# 0
def load_dataset(input_context=None):
    """
    Load a dataset suitable for unsupervised training in a distributed environment:
        - If `input_context` is provided, it is used to build a single-replica dataset.
        - If `input_context` is not provided, builds a cross-replica dataset.

    In some cases (e.g. multi-GPU) a single dataset can be copied to multiple devices,
    so a cross-replica dataset should be built. This is also the default for
    single-device training. In other cases (e.g. TPU pods or clusters) each device
    must build its own single-replica dataset by calling this function.

    For most purposes it suffices to call this function with `input_context=None`.
    Do not manually pass an input context; instead create per-replica datasets with
    `tf.distribute.experimental_distribute_datasets_from_function`.

    Records are read from FLAGS.data_file if set, otherwise from the TFRecord files
    listed in FLAGS.data_listing. The stream repeats indefinitely, is optionally
    shuffled, batched, decoded from uint8 and rescaled via ``(x - 128) / 128``, and
    optionally augmented with a random flip/rotation/transpose chosen per batch.

    Parameters
    ----------
    input_context : None or tf.distribute.InputContext, optional
        If None, build a cross-replica, whole-dataset pipeline, as normal.
        If provided, shard the dataset to a single replica device.

    Returns
    -------
    tf.data.Dataset
        A dataset of batched examples, where each example is a set of coterminous
        input bands. Examples are flattened for efficient communication with hardware.
    """
    image_dims = (FLAGS.data_tilesize**2) * gf.n_bands()

    # NOTE: Make sure to test and fine-tune these optimizations for any new hardware.
    options = tf.data.Options()
    options.experimental_deterministic = False
    if FLAGS.enable_experimental_optimization:
        options.experimental_optimization.autotune_buffers = True
        options.experimental_optimization.autotune_cpu_budget = True
        options.experimental_optimization.parallel_batch = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.map_vectorization.enabled = False

    # If an input context spanning several input pipelines is provided, _this
    # function_ is copied to each replica and called once independently per copy.
    # The function should return an appropriate dataset for the replica -- manually
    # sharded to be non-overlapping with other replicas and batched to the replica
    # batch size.
    # NOTE(review): a provided context with a single input pipeline falls through to
    # the global batch size below -- confirm that matches how the caller distributes
    # the dataset.
    if input_context is not None and input_context.num_input_pipelines > 1:
        shard_data = True
        batch_size = input_context.get_per_replica_batch_size(FLAGS.batch_size)

    # Otherwise, the dataset will be built once and copied to each replica. Sharding
    # and distribution will be handled automatically by the distribution strategy,
    # so we batch by global batch size.
    else:
        shard_data = False
        batch_size = FLAGS.batch_size

    if FLAGS.data_file:
        dataset = tf.data.TFRecordDataset(
            FLAGS.data_file).with_options(options)

        if shard_data:  # Shard granularity: lines of tfrecords
            dataset = dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)

        dataset = dataset.repeat()
    else:
        dataset = (tf.data.experimental.CsvDataset(
            FLAGS.data_listing, [tf.dtypes.string]).with_options(options).map(
                lambda x: tf.reshape(x, []),
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
            ))

        if shard_data:  # Shard granularity: full tfrecord files
            dataset = dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)

        dataset = dataset.repeat().interleave(
            tf.data.TFRecordDataset,
            cycle_length=FLAGS.tfrecord_parallel_reads,
            block_length=FLAGS.tfrecord_sequential_reads,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

    feature_spec = {
        FLAGS.data_feature_name: tf.io.FixedLenFeature([], tf.dtypes.string)
    }

    def preprocess(batch):
        # Parse a batch of serialized examples into flat float vectors of length
        # `image_dims`, rescaled from uint8 via (x - 128) / 128.
        batch = tf.io.parse_example(batch, feature_spec)
        batch = tf.io.decode_raw(batch[FLAGS.data_feature_name],
                                 tf.dtypes.uint8,
                                 fixed_length=image_dims)
        batch = tf.cast(batch, tf.dtypes.float32)
        batch = (batch - 128.0) / 128.0
        return batch

    def augment(batch):
        # `preprocess` yields flattened examples of shape (batch, image_dims), but
        # the flips, rotations and transpose below need image-shaped tensors.
        # Reshape to images first, then flatten again so the dataset keeps yielding
        # the flat layout described in the docstring.
        images = tf.reshape(
            batch, (-1, FLAGS.data_tilesize, FLAGS.data_tilesize, gf.n_bands()))

        # One of eight transforms is drawn per batch, not per example.
        which_aug = tf.random.uniform(shape=(),
                                      dtype=tf.dtypes.int32,
                                      minval=0,
                                      maxval=8)
        aug_options = {
            0: lambda: images,
            1: lambda: tf.transpose(images, perm=[0, 2, 1, 3]),
            2: lambda: tf.image.flip_up_down(images),
            3: lambda: tf.image.flip_left_right(tf.image.rot90(images, k=1)),
            4: lambda: tf.image.flip_left_right(images),
            5: lambda: tf.image.rot90(images, k=1),
            6: lambda: tf.image.rot90(images, k=2),
            7: lambda: tf.image.rot90(images, k=3),
        }
        images = tf.switch_case(which_aug, aug_options)
        return tf.reshape(images, (-1, image_dims))

    if FLAGS.shuffle_buffer_size > 0:
        dataset = dataset.shuffle(
            FLAGS.shuffle_buffer_size,
            reshuffle_each_iteration=True,
            seed=FLAGS.random_seed,
        )

    # NOTE: examples are provided "flattened" and must be reshaped into images
    dataset = dataset.batch(batch_size, drop_remainder=True).map(
        preprocess, tf.data.experimental.AUTOTUNE)

    if FLAGS.enable_augmentation:
        dataset = dataset.map(augment, tf.data.experimental.AUTOTUNE)

    return dataset.prefetch(FLAGS.prefetch_batches)