def input_shape():
    """
    Get the shape of a single batch of model inputs.

    Returns
    -------
    tuple
        `(replica batch size, FLAGS.model_tilesize, FLAGS.model_tilesize,
        number of bands)`.
    """
    replica_batch = csf.distribution.replica_batch_size()
    tilesize = FLAGS.model_tilesize
    return (replica_batch, tilesize, tilesize, gf.n_bands())
def data_shape():
    """
    Get the shape of a single batch of input imagery.

    Returns
    -------
    tuple
        `(replica batch size, FLAGS.data_tilesize, FLAGS.data_tilesize,
        number of bands)`.
    """
    replica_batch = csf.distribution.replica_batch_size()
    tilesize = FLAGS.data_tilesize
    return (replica_batch, tilesize, tilesize, gf.n_bands())
def _create_view(scene, dropout_rate, seed=None):
    """
    Apply augmentation to a set of input imagery, creating a new view.

    Note that this function is autograph-traced and takes a Python integer input
    (seed), so keep the number of calls with distinct seeds to a minimum. Do not
    pass Python values to any other argument.

    Parameters
    ----------
    scene : tf.Tensor
        A tensor of aligned input imagery.
    dropout_rate : None or tf.Tensor
        A scalar, float Tensor holding the current dropout rate. Included as an
        argument to work well with scheduling and autograph. If None, band
        dropout is skipped.
    seed : int, optional
        Random seed to use. Used to ensure that views get different random
        numbers. If None, `FLAGS.random_seed` is used.

    Returns
    -------
    tf.Tensor
        A view of the input imagery with crop, band dropout, and jitter applied.
    """
    # Compare against None explicitly: `seed or ...` would silently replace an
    # explicit seed of 0 with the global seed.
    if seed is None:
        seed = FLAGS.random_seed

    # Cropping is only needed when the model consumes smaller tiles than the
    # dataset provides.
    if FLAGS.model_tilesize != FLAGS.data_tilesize:
        scene = tf.image.random_crop(scene, input_shape(), name="crop", seed=seed)

    if FLAGS.random_brightness_delta:
        scene = tf.image.random_brightness(
            scene, FLAGS.random_brightness_delta, seed=seed
        )

    if FLAGS.random_contrast_delta:
        scene = tf.image.random_contrast(
            scene,
            1.0 - FLAGS.random_contrast_delta,
            1.0 + FLAGS.random_contrast_delta,
            seed=seed,
        )

    # The noise shape broadcasts over both spatial axes, so each (example, band)
    # pair is kept or dropped as a whole.
    if dropout_rate is not None:
        scene = tf.nn.dropout(
            scene,
            dropout_rate,
            noise_shape=(csf.distribution.replica_batch_size(), 1, 1, gf.n_bands()),
            name="band_dropout",
            seed=seed,
        )

    if FLAGS.flips:
        scene = tf.image.random_flip_up_down(scene, seed=seed)
        scene = tf.image.random_flip_left_right(scene, seed=seed)

    if FLAGS.rotation:
        # `maxval` is exclusive for integer dtypes, so 4 is required for all four
        # quarter-turn rotations (k in {0, 1, 2, 3}) to be reachable.
        scene = tf.image.rot90(
            scene,
            k=tf.random.uniform(
                shape=(), minval=0, maxval=4, dtype=tf.dtypes.int32, seed=seed
            ),
        )

    return scene
def load_osm_dataset(tfrecord_glob, band_indices):
    """
    Load the OpenStreetMap classification dataset.

    Parameters
    ----------
    tfrecord_glob : string
        A glob specifying the path of files used to load the dataset.
    band_indices : [int]
        List of input bands to keep.

    Returns
    -------
    tf.data.Dataset
        The dataset, yielding (image, label) pairs.
    """
    feature_spec = {
        "spot_naip_phr": tf.io.FixedLenSequenceFeature(
            [], dtype=tf.string, allow_missing=True
        ),
        "label": tf.io.FixedLenSequenceFeature(
            [], dtype=tf.int64, allow_missing=True
        ),
    }
    image_shape = (OSM_TILESIZE, OSM_TILESIZE, gf.n_bands())
    label_shape = (len(OSM_CLASSES),)

    def _decode_example(serialized):
        """Parse one serialized example into a scaled image and one-hot label."""
        parsed = tf.io.parse_single_example(serialized, feature_spec)

        raw_pixels = tf.io.decode_raw(parsed["spot_naip_phr"], tf.uint8)
        image = tf.reshape(raw_pixels, image_shape)

        # Keep only the requested bands, in the order they were requested.
        selected = [
            tf.expand_dims(image[..., index], axis=-1) for index in band_indices
        ]
        image = tf.concat(selected, axis=-1)

        one_hot = tf.one_hot(parsed["label"], depth=len(OSM_CLASSES))
        target = tf.reshape(one_hot, label_shape)

        # Rescale the uint8 pixel values by 1/128 and shift down by one.
        image = (tf.cast(image, tf.dtypes.float32) / 128.0) - 1.0
        return image, target

    options = tf.data.Options()
    options.experimental_deterministic = False
    options.experimental_optimization.parallel_batch = True
    options.experimental_optimization.map_vectorization.enabled = True

    filenames = tf.data.Dataset.list_files(tfrecord_glob)
    records = filenames.interleave(tf.data.TFRecordDataset).with_options(options)
    return records.map(_decode_example).cache()
def encoder_head(
    size,
    bands=None,
    batch_size=8,
    input_scaling=1.0,
    checkpoint=None,
    trainable=True,
    assert_checkpoint=False,
):
    """
    Build a ResNet encoder from a subset of available bands, re-ordering the bands
    correctly and filling in any missing bands with zeroes.

    Parameters
    ----------
    size : int
        Tilesize this encoder accepts.
    bands : [string], optional
        List of bands this encoder uses. All other bands are filled in with zeroes.
        May appear in any order, but must be a subset of FLAGS.bands.
        If None, use all available bands.
    batch_size : int, optional
        Batch size for the encoder.
    input_scaling : float, optional
        Number to multiply inputs by. When using CSF weights, should be
        1 / (1 - final dropout rate used in training).
    checkpoint : str, optional
        - If "imagenet", the encoder is initialized with ImageNet weights and must
          have exactly 3 bands.
        - If "random" then initialize randomly.
        - If a path to a checkpoint file (locally or in Google cloud), initialize
          from that checkpoint.
        - If a path to a directory containing checkpoint files (locally or in
          Google cloud), initialize from the latest checkpoint in that directory.
    trainable : bool, optional
        If False, the encoder is frozen.
    assert_checkpoint : bool, optional
        If True, require that a checkpoint is loaded.

    Returns
    -------
    tf.Tensor
        The overall model inputs.
    tf.Tensor
        The encoder inputs.
    tf.Tensor
        The encoder outputs.

    Raises
    ------
    ValueError
        If `bands` is empty, or if `checkpoint` is "imagenet" and the number of
        bands is not exactly three.
    """
    if bands is None:
        bands = FLAGS.bands

    n_bands = len(bands)
    if n_bands <= 0:
        raise ValueError("You must provide some bands.")

    # Load upstream weights into encoder
    if checkpoint == "imagenet":
        if n_bands != 3:
            raise ValueError(
                "If initializing an encoder with ImageNet weights, you "
                "must provide exactly three bands."
            )
        n_input_bands = 3
        encoder = resnet_encoder(3, weights="imagenet")
    elif checkpoint == "random":
        n_input_bands = gf.n_bands()
        encoder = resnet_encoder(n_input_bands, weights=None)
    else:
        # NOTE(review): the default `checkpoint=None` also reaches this branch and
        # is passed straight to `tf.train.latest_checkpoint` -- confirm that is
        # intended, or require callers to pass an explicit checkpoint.
        n_input_bands = gf.n_bands()
        encoder = resnet_encoder(n_input_bands)
        ckpt = tf.train.Checkpoint(encoder=encoder)
        # A directory resolves to its latest checkpoint; otherwise the value is
        # used as a checkpoint path directly.
        ckpt_path = tf.train.latest_checkpoint(checkpoint) or checkpoint
        ckpt_restore_status = ckpt.restore(ckpt_path)
        if assert_checkpoint:
            ckpt_restore_status.assert_nontrivial_match().expect_partial()
        else:
            ckpt_restore_status.expect_partial()

    encoder.trainable = trainable
    model_inputs = Input(batch_shape=(batch_size, size, size, n_bands))

    def _band_or_zeros(wanted, is_match):
        """Slice the first provided band matching `wanted`, or build a zero band."""
        for band_i, band in enumerate(bands):
            if is_match(band, wanted):
                return K.expand_dims(model_inputs[..., band_i], axis=-1)
        return K.zeros(shape=(batch_size, size, size, 1))

    # Reorganize the present inputs according to the order given
    if n_input_bands == 3:
        # Order RGB bands correctly for ImageNet experiments; bands are matched
        # by name suffix.
        to_concat = [
            _band_or_zeros(rgb_band, lambda band, suffix: band.endswith(suffix))
            for rgb_band in ("red", "green", "blue")
        ]
    else:
        # Follow the order of FLAGS.bands; bands are matched by exact name.
        to_concat = [
            _band_or_zeros(default_band, lambda band, name: band == name)
            for default_band in FLAGS.bands
        ]

    all_inputs = Concatenate(axis=-1)(to_concat)

    # Multiply inputs according to missing bands
    scaled_inputs = Lambda(lambda x: x * input_scaling)(all_inputs)
    encoded = encoder(scaled_inputs)

    return model_inputs, scaled_inputs, encoded
def run_unsupervised_training():
    """ Perform a full unsupervised training run programmatically. """
    # Overview: builds the distributed dataset, the encoder, optimizer, checkpoint
    # manager and metrics, then loops forever (or until FLAGS.max_batches):
    # update the schedules, run FLAGS.callback_frequency training steps, write
    # metrics, and periodically save a checkpoint. All configuration is read from
    # FLAGS; nothing is returned.
    logging.info(
        "Starting unsupervised training run with flags:\n{}".format(
            FLAGS.flags_into_string()
        )
    )

    csf.distribution.initialize()
    data_shape = csf.data.data_shape()
    layers_and_weights = layer_loss_weights().items()

    logging.debug("Building dataset.")
    with csf.distribution.tpu_worker_context():
        dataset = csf.distribution.distribute_dataset_fn(csf.data.load_dataset)
        # A single iterator is created here and reused by every `train_steps` call.
        dataset_iterator = iter(dataset)

    with csf.distribution.tpu_worker_context(), csf.distribution.distributed_context():
        tf.random.set_seed(FLAGS.random_seed)

        logging.debug("Creating schedules.")
        # The schedules are non-trainable scalar variables; their values are
        # reassigned from Python at the top of every outer loop iteration below.
        learning_rate = tf.Variable(
            0.0,
            trainable=False,
            name="learning_rate",
            dtype=tf.dtypes.float32,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        # NOTE(review): the variable name "dropout_rate_rate" looks like a typo
        # for "dropout_rate" -- confirm before renaming.
        dropout_rate = tf.Variable(
            0.0,
            trainable=False,
            name="dropout_rate_rate",
            dtype=tf.dtypes.float32,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

        logging.debug("Building model and optimizer.")
        encoder = resnet_encoder(gf.n_bands())
        encoder_vars = encoder.trainable_variables
        optimizer = tf.optimizers.Adam(learning_rate, clipnorm=FLAGS.gradient_clipnorm)

        logging.debug("Building checkpoint objects.")
        # Only the encoder and the optimizer are checkpointed.
        ckpt = tf.train.Checkpoint(encoder=encoder, optimizer=optimizer)
        ckpt_manager = tf.train.CheckpointManager(
            ckpt,
            FLAGS.out_dir,
            FLAGS.max_checkpoints,
            FLAGS.keep_checkpoint_every_n_hours,
        )

        # An explicit FLAGS.initial_checkpoint takes precedence over the latest
        # checkpoint known to the manager for FLAGS.out_dir.
        restore_ckpt = FLAGS.initial_checkpoint or ckpt_manager.latest_checkpoint
        if restore_ckpt:
            logging.info("Continuing training from checkpoint: {}".format(restore_ckpt))
            ckpt.restore(restore_ckpt)
        else:
            logging.info("Initializing encoder and optimizer from scratch.")

        logging.debug("Building metrics.")
        summary_writer = tf.summary.create_file_writer(FLAGS.out_dir)
        loss_metric = tf.metrics.Mean("loss", tf.dtypes.float32)
        # One loss, accuracy and representation-norm metric per layer named in
        # `layer_loss_weights()`.
        layer_loss_metrics = {
            layer: tf.metrics.Mean("{}_loss".format(layer), tf.dtypes.float32)
            for layer in layer_loss_weights().keys()
        }
        layer_accuracy_metrics = {
            layer: tf.metrics.Mean("{}_accuracy".format(layer), tf.dtypes.float32)
            for layer in layer_loss_weights().keys()
        }
        layer_rep_scale_metrics = {
            layer: tf.metrics.Mean(
                "{}_representation_size".format(layer), tf.dtypes.float32
            )
            for layer in layer_loss_weights().keys()
        }
        all_metrics = (
            [loss_metric]
            + list(layer_loss_metrics.values())
            + list(layer_accuracy_metrics.values())
            + list(layer_rep_scale_metrics.values())
        )

        # Writes both schedule values and every metric's current result at `step`,
        # resetting each metric after it is written.
        def write_metrics(step):
            with summary_writer.as_default():
                tf.summary.scalar("learning_rate", learning_rate, step)
                tf.summary.scalar("band_dropout_rate", dropout_rate, step)
                for metric in all_metrics:
                    tf.summary.scalar(metric.name, metric.result(), step)
                    metric.reset_states()
            summary_writer.flush()

        logging.debug("Building distributed execution functions.")

        # One optimization step: build two views of the same batch, encode both,
        # and sum the weighted per-layer contrastive losses.
        @csf.distribution.distribute_computation
        def _replicated_training_step(batch):
            # The batch is reshaped to `csf.data.data_shape()` before augmentation.
            batch = tf.reshape(batch, data_shape)
            with tf.name_scope("training_step"):
                # The two views differ only in the seed passed to `_create_view`.
                with tf.name_scope("view_1"):
                    view_1 = _create_view(batch, dropout_rate, seed=1)
                with tf.name_scope("view_2"):
                    view_2 = _create_view(batch, dropout_rate, seed=2)

                losses = []
                with tf.GradientTape() as tape:
                    # The encoder output is indexed by layer name below.
                    representations_1 = encoder(view_1)
                    representations_2 = encoder(view_2)

                    for layer, weight in layers_and_weights:
                        with tf.name_scope("layer_{}".format(layer)):
                            rep_1 = representations_1[layer]
                            rep_2 = representations_2[layer]
                            loss, accuracy = _contrastive_loss(rep_1, rep_2)

                            # Plot the average 2-norm of representations
                            with tf.name_scope("compute_scale"):
                                rep_flat = tf.reshape(
                                    rep_1, (csf.distribution.replica_batch_size(), -1)
                                )
                                rep_norms = tf.norm(rep_flat, axis=-1)
                            layer_rep_scale_metrics[layer].update_state(rep_norms)

                            # The metric records the unweighted loss; only the
                            # optimized total applies the layer weight.
                            losses.append(weight * loss)
                            layer_loss_metrics[layer].update_state(loss)
                            layer_accuracy_metrics[layer].update_state(accuracy)

                    loss_total = tf.reduce_sum(losses, name="loss_total")

                gradients = tape.gradient(loss_total, encoder_vars)
                optimizer.apply_gradients(zip(gradients, encoder_vars))
                loss_metric.update_state(loss_total)

        # Runs `steps` training steps inside a single traced function, pulling
        # one element from the iterator per step.
        @tf.function
        def train_steps(iter_, steps):
            for _ in tf.range(steps):
                _replicated_training_step((next(iter_),))

        logging.info("Beginning unsupervised training.")
        while True:
            # `step` is read once per outer iteration, before the inner steps run.
            # NOTE(review): the metrics, checkpoint cadence and max_batches test
            # below therefore all use the step count from the START of the
            # interval, so training runs one extra interval past
            # FLAGS.max_batches -- confirm this is intended.
            step = optimizer.iterations.numpy()
            logging.info("Starting step: {}".format(step))

            learning_rate.assign(
                csf.utils.optional_warmup(
                    step, FLAGS.learning_rate, FLAGS.learning_rate_warmup_batches
                )
            )
            dropout_rate.assign(
                csf.utils.optional_warmup(
                    step,
                    FLAGS.band_dropout_rate,
                    FLAGS.band_dropout_rate_warmup_batches,
                )
            )

            train_steps(
                dataset_iterator,
                tf.convert_to_tensor(FLAGS.callback_frequency, dtype=tf.int32),
            )
            write_metrics(step)

            # Save on every FLAGS.checkpoint_frequency-th callback interval.
            if (step // FLAGS.callback_frequency) % FLAGS.checkpoint_frequency == 0:
                ckpt_path = ckpt_manager.save()
                logging.info("Saved checkpoint at path: {}.".format(ckpt_path))

            # A falsy FLAGS.max_batches disables the limit, so the loop never exits.
            if FLAGS.max_batches and step >= FLAGS.max_batches:
                break

        # A final checkpoint is always saved after the loop exits, even if one was
        # saved in the last iteration.
        ckpt_path = ckpt_manager.save()
        logging.info("Done with unsupervised training.")
        logging.info("Saving final checkpoint at path: {}.".format(ckpt_path))
def load_dataset(input_context=None):
    """
    Load a dataset suitable for unsupervised training in a distributed environment:

    - If `input_context` is provided, it is used to build a single-replica dataset.
    - If `input_context` is not provided, builds a cross-replica dataset.

    In some cases (e.g. multi-GPU) a single dataset can be copied to multiple
    devices, so a cross-replica dataset should be built. This is also the default
    for single-device training. In other cases (e.g. TPU pods or clusters) each
    device must build its own single-replica dataset by calling this function.

    For most purposes it suffices to call this function with `input_context=None`.
    Do not manually pass an input context; instead create per-replica datasets with
    `tf.distribute.experimental_distribute_datasets_from_function`.

    Parameters
    ----------
    input_context : None or tf.distribute.InputContext, optional
        If None, build a cross-replica, whole-dataset pipeline, as normal.
        If provided, shard the dataset to a single replica device.

    Returns
    -------
    tf.data.Dataset
        A dataset of batched examples, where each example is a set of coterminous
        input bands. Examples are flattened for efficient communication with
        hardware; this holds whether or not augmentation is enabled.
    """
    image_dims = (FLAGS.data_tilesize**2) * gf.n_bands()

    # NOTE: Make sure to test and fine-tune these optimizations for any new hardware.
    options = tf.data.Options()
    options.experimental_deterministic = False
    if FLAGS.enable_experimental_optimization:
        options.experimental_optimization.autotune_buffers = True
        options.experimental_optimization.autotune_cpu_budget = True
        options.experimental_optimization.parallel_batch = True
        options.experimental_optimization.map_fusion = True
        options.experimental_optimization.map_vectorization.enabled = False

    if input_context is not None and input_context.num_input_pipelines > 1:
        # _This function_ is copied to each replica and called once independently
        # per copy. The function should return an appropriate dataset for the
        # replica -- manually sharded to be non-overlapping with other replicas
        # and batched to the replica batch size.
        shard_data = True
        batch_size = input_context.get_per_replica_batch_size(FLAGS.batch_size)
    else:
        # The dataset is built once and copied to each replica. Sharding and
        # distribution will be handled automatically by the distribution strategy,
        # so we batch by global batch size.
        shard_data = False
        batch_size = FLAGS.batch_size

    if FLAGS.data_file:
        dataset = tf.data.TFRecordDataset(FLAGS.data_file).with_options(options)
        if shard_data:
            # Shard granularity: lines of tfrecords
            dataset = dataset.shard(
                input_context.num_input_pipelines, input_context.input_pipeline_id
            )
        dataset = dataset.repeat()
    else:
        # FLAGS.data_listing is a CSV with one string column of tfrecord paths;
        # each row is reshaped to a scalar filename.
        dataset = (
            tf.data.experimental.CsvDataset(FLAGS.data_listing, [tf.dtypes.string])
            .with_options(options)
            .map(
                lambda x: tf.reshape(x, []),
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
            )
        )
        if shard_data:
            # Shard granularity: full tfrecord files
            dataset = dataset.shard(
                input_context.num_input_pipelines, input_context.input_pipeline_id
            )
        dataset = dataset.repeat().interleave(
            tf.data.TFRecordDataset,
            cycle_length=FLAGS.tfrecord_parallel_reads,
            block_length=FLAGS.tfrecord_sequential_reads,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )

    feature_spec = {
        FLAGS.data_feature_name: tf.io.FixedLenFeature([], tf.dtypes.string)
    }

    def preprocess(batch):
        """Decode a batch of serialized examples into flat, rescaled float rows."""
        batch = tf.io.parse_example(batch, feature_spec)
        batch = tf.io.decode_raw(
            batch[FLAGS.data_feature_name], tf.dtypes.uint8, fixed_length=image_dims
        )
        batch = tf.cast(batch, tf.dtypes.float32)
        batch = (batch - 128.0) / 128.0
        return batch

    # Batches are produced with `drop_remainder=True`, so the batch dimension is
    # static and can be used when reshaping for augmentation.
    image_shape = (
        batch_size,
        FLAGS.data_tilesize,
        FLAGS.data_tilesize,
        gf.n_bands(),
    )

    def augment(batch):
        """Apply one randomly chosen flip/rotation/transpose to the whole batch."""
        # `preprocess` yields flattened rows, but the transpose and `tf.image` ops
        # below need image-shaped input, so reshape first and flatten again after.
        images = tf.reshape(batch, image_shape)
        which_aug = tf.random.uniform(
            shape=(), dtype=tf.dtypes.int32, minval=0, maxval=8
        )
        aug_options = {
            0: lambda: images,
            1: lambda: tf.transpose(images, perm=[0, 2, 1, 3]),
            2: lambda: tf.image.flip_up_down(images),
            3: lambda: tf.image.flip_left_right(tf.image.rot90(images, k=1)),
            4: lambda: tf.image.flip_left_right(images),
            5: lambda: tf.image.rot90(images, k=1),
            6: lambda: tf.image.rot90(images, k=2),
            7: lambda: tf.image.rot90(images, k=3),
        }
        images = tf.switch_case(which_aug, aug_options)
        return tf.reshape(images, (batch_size, image_dims))

    if FLAGS.shuffle_buffer_size > 0:
        dataset = dataset.shuffle(
            FLAGS.shuffle_buffer_size,
            reshuffle_each_iteration=True,
            seed=FLAGS.random_seed,
        )

    # NOTE: examples are provided "flattened" and must be reshaped into images
    dataset = dataset.batch(batch_size, drop_remainder=True).map(
        preprocess, tf.data.experimental.AUTOTUNE
    )

    if FLAGS.enable_augmentation:
        dataset = dataset.map(augment, tf.data.experimental.AUTOTUNE)

    return dataset.prefetch(FLAGS.prefetch_batches)