def train_step(strategy: tf.distribute.Strategy,
               data_it,
               disc: Model,
               gen: Model,
               model_g: Model,
               model_d: Model,
               batch_size: int,
               z_size: int,
               num_cat: int,
               metrics: Dict[str, keras.metrics.Mean]):
    """Run three discriminator updates followed by one generator update.

    Args:
        strategy: Distribution strategy whose ``run`` executes each update.
        data_it: Iterator; ``next(data_it)`` supplies the image batch for each
            discriminator update.
        disc: Model whose trainable variables and optimizer are used for the
            discriminator update.
        gen: Model whose trainable variables and optimizer are used for the
            generator update.
        model_g: Called with the latent input; returns
            ``(disc_gen, cat_output)``.
        model_d: Called with ``(z_input, batch_images, eps)``; returns
            ``(disc_gen, disc_real, iwgan_loss, cat_output)``.
        batch_size: Number of latent samples requested for the generator
            update. The discriminator update uses the leading dimension of
            its image batch instead.
        z_size: Forwarded to ``CqGAN.generate_z``.
        num_cat: Forwarded to ``CqGAN.generate_z``.
        metrics: Trackers updated in place under the keys ``"disc_gen"``,
            ``"disc_real"``, ``"loss_real"``, ``"iwgan_loss"``, ``"loss_gen"``
            and ``"loss_cat"``.
    """

    def discriminate(batch_images: tf.Tensor):
        """Apply one optimizer step to the discriminator variables."""
        disc_vars = disc.trainable_variables
        # Sized from the batch actually received rather than the outer
        # ``batch_size`` argument.
        n_images = batch_images.shape[0]
        eps = tf.random.uniform((n_images, 1, 1, 1), minval=0, maxval=1)
        z_input, _, cat_input = CqGAN.generate_z(n_images, z_size, num_cat)

        with tf.GradientTape() as tape:
            disc_gen, disc_real, iwgan_loss, cat_output = model_d(
                (z_input, batch_images, eps), training=True)
            cat_loss = CqGAN.get_loss_cat(cat_input, cat_output)
            total_loss = -disc_gen + disc_real + iwgan_loss + cat_loss

        gradients = tape.gradient(total_loss, disc_vars)
        disc.optimizer.apply_gradients(zip(gradients, disc_vars))

        metrics["disc_gen"].update_state(disc_gen)
        metrics["disc_real"].update_state(disc_real)
        # "loss_real" tracks the full discriminator objective.
        metrics["loss_real"].update_state(total_loss)
        metrics["iwgan_loss"].update_state(iwgan_loss)

    def generate():
        """Apply one optimizer step to the generator variables."""
        gen_vars = gen.trainable_variables
        z_input, _, cat_input = CqGAN.generate_z(batch_size, z_size, num_cat)

        with tf.GradientTape() as tape:
            disc_gen, cat_output = model_g(z_input, training=True)
            loss_gen = disc_gen
            loss_cat = CqGAN.get_loss_cat(cat_input, cat_output)
            total_loss = loss_gen + loss_cat

        gradients = tape.gradient(total_loss, gen_vars)
        gen.optimizer.apply_gradients(zip(gradients, gen_vars))

        metrics["loss_gen"].update_state(loss_gen)
        metrics["loss_cat"].update_state(loss_cat)

    for _ in range(3):
        images = next(data_it)
        strategy.run(discriminate, args=(images,))

    # NOTE(review): the original formatting was lost; the generator update is
    # placed after the discriminator loop (one generator step per three
    # discriminator steps) -- confirm against the intended schedule.
    strategy.run(generate)
def compute_predictions(
    model: PredictionModel,
    dataset: tf.data.Dataset,
    strategy: tf.distribute.Strategy,
    batch_size: int,
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
    """Batch the dataset, run the model on it, and yield per-example results.

    Args:
        model: A function that takes tensor-valued features and returns a
            vector of predictions.
        dataset: The unbatched dataset to evaluate; it is batched and
            prefetched here before being distributed.
        strategy: The distribution strategy used to run the model.
        batch_size: The batch size applied to `dataset`.

    Yields:
        Pairs of model predictions and the corresponding metadata, one pair
        per example. The reported time is the batch wall-clock time divided
        by the number of predictions in the batch.
    """
    with strategy.scope():
        shard_options = tf.data.Options()
        shard_options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA)
        batched = dataset.batch(batch_size)
        batched = batched.prefetch(tf.data.experimental.AUTOTUNE)
        batched = batched.with_options(shard_options)

        for features in strategy.experimental_distribute_dataset(batched):
            time_start = time.time()
            # TODO(josipd): Figure this out better. We can't easily filter,
            # as they are PerReplica values, not tensors.
            model_inputs = (
                {"image": features["image"]}
                if isinstance(strategy, tf.distribute.experimental.TPUStrategy)
                else features)
            replica_outputs = strategy.run(model, args=(model_inputs,))
            predictions = materialize(strategy, replica_outputs)
            elapsed = time.time() - time_start

            num_examples = predictions.shape[0]
            seconds_per_example = elapsed / num_examples
            metadatas = materialize(strategy, features["metadata"])

            for index in range(num_examples):
                example_predictions = types.ModelPredictions(
                    predictions=[predictions[index]],
                    time_in_s=seconds_per_example)
                yield example_predictions, _slice_dictionary(metadatas, index)
def compute_predictions(
    model: PredictionModel,
    dataset: tf.data.Dataset,
    strategy: tf.distribute.Strategy,
) -> Iterator[types.ModelPredictions]:
    """Run the model over an already-batched dataset and yield predictions.

    Only the ``"image"`` entry of each feature batch is passed to the model.

    Args:
        model: A function that takes tensor-valued features and returns a
            vector of predictions.
        dataset: The dataset to evaluate; it is expected to yield batches of
            tensors.
        strategy: The distribution strategy used to run the model.

    Yields:
        One `types.ModelPredictions` per example, carrying the element id
        (``None`` when the features have no ``"element_id"`` entry), the
        sliced metadata, the prediction, and the batch wall-clock time divided
        by the number of predictions in the batch.
    """
    for features in strategy.experimental_distribute_dataset(dataset):
        # TODO(josipd): Figure out how to pass only tpu-allowed types.
        time_start = time.time()
        model_inputs = {"image": features["image"]}
        replica_outputs = strategy.run(model, args=(model_inputs,))
        predictions = materialize(strategy, replica_outputs)
        elapsed = time.time() - time_start

        num_examples = predictions.shape[0]
        seconds_per_example = elapsed / num_examples

        try:
            element_ids = materialize(strategy, features["element_id"])
        except KeyError:
            # No ids available for this batch: use a None placeholder each.
            element_ids = [None] * num_examples
        metadatas = materialize(strategy, features["metadata"])

        for index in range(num_examples):
            element_id = element_ids[index]
            example_metadata = _slice_dictionary(metadatas, index)
            yield types.ModelPredictions(
                element_id=element_id,
                metadata=example_metadata,
                predictions=[predictions[index]],
                time_in_s=seconds_per_example)
def predict_image(strategy: tf.distribute.Strategy, gen: Model,
                  input_z: tf.Tensor):
    """Call ``gen`` on ``input_z`` through ``strategy.run``.

    Args:
        strategy: Distribution strategy that executes the call.
        gen: Model invoked with ``input_z`` as its only argument.
        input_z: Input passed to ``gen``; it is captured by closure rather
            than forwarded through ``strategy.run`` arguments.

    Returns:
        Whatever ``strategy.run`` returns for the generator call.
    """

    def _run_generator():
        return gen(input_z)

    return strategy.run(_run_generator)