Example #1
from typing import Dict

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model

# `CqGAN` is assumed to be the surrounding project's class providing the
# `generate_z` and `get_loss_cat` helpers; it is not defined in this example.
def train_step(strategy: tf.distribute.Strategy, data_it, disc: Model,
               gen: Model, model_g: Model, model_d: Model, batch_size: int,
               z_size: int, num_cat: int,
               metrics: Dict[str, keras.metrics.Mean]):
    # One critic (discriminator) update on the current replica.
    def discriminate(batch_images: tf.Tensor):
        train_vars = disc.trainable_variables

        # Use the per-replica batch size from the data, shadowing the global one.
        batch_size = batch_images.shape[0]

        # Interpolation coefficients for the gradient-penalty (iWGAN) term.
        eps = tf.random.uniform((batch_size, 1, 1, 1), 0, 1)
        # Sample latent noise and categorical codes (project-specific helper).
        z_input, _, cat_input = CqGAN.generate_z(batch_size, z_size, num_cat)

        with tf.GradientTape() as tape:
            disc_gen, disc_real, iwgan_loss, cat_output = model_d(
                (z_input, batch_images, eps), training=True)
            full_loss = (-disc_gen + disc_real + iwgan_loss
                         + CqGAN.get_loss_cat(cat_input, cat_output))

        grads = tape.gradient(full_loss, train_vars)
        disc.optimizer.apply_gradients(zip(grads, train_vars))

        metrics["disc_gen"].update_state(disc_gen)
        metrics["disc_real"].update_state(disc_real)
        metrics["loss_real"].update_state(full_loss)
        metrics["iwgan_loss"].update_state(iwgan_loss)

    # One generator update on the current replica.
    def generate():
        train_vars = gen.trainable_variables

        z_input, _, cat_input = CqGAN.generate_z(batch_size, z_size, num_cat)

        with tf.GradientTape() as tape:
            disc_gen, cat_output = model_g(z_input, training=True)

            loss_gen = disc_gen
            loss_cat = CqGAN.get_loss_cat(cat_input, cat_output)

            full_loss = loss_gen + loss_cat

        grads = tape.gradient(full_loss, train_vars)
        gen.optimizer.apply_gradients(zip(grads, train_vars))

        metrics["loss_gen"].update_state(loss_gen)
        metrics["loss_cat"].update_state(loss_cat)

    # Several critic updates per generator update, as is usual for WGAN-style
    # training.
    for _ in range(3):
        batch_images = next(data_it)
        strategy.run(discriminate, args=(batch_images, ))

    strategy.run(generate)
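
A minimal driver sketch for the step above, assuming a mirrored strategy and a
plain training loop. The `make_models` helper, `train_dataset`, and
`num_steps` are placeholders introduced here for illustration; the metric
names simply mirror the ones `train_step` updates.

strategy = tf.distribute.MirroredStrategy()
batch_size, z_size, num_cat = 64, 100, 10

with strategy.scope():
    # Hypothetical helper building the four models `train_step` expects.
    disc, gen, model_g, model_d = make_models(z_size, num_cat)
    metrics = {name: keras.metrics.Mean(name)
               for name in ("disc_gen", "disc_real", "loss_real",
                            "iwgan_loss", "loss_gen", "loss_cat")}

# `train_dataset` is an assumed tf.data.Dataset of image batches.
data_it = iter(strategy.experimental_distribute_dataset(train_dataset))
for _ in range(num_steps):  # `num_steps` is a placeholder.
    train_step(strategy, data_it, disc, gen, model_g, model_d,
               batch_size, z_size, num_cat, metrics)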
Example #2
import time
from typing import Iterator, Tuple

import tensorflow as tf

# `PredictionModel`, `types`, `materialize`, and `_slice_dictionary` are
# assumed to come from the surrounding project; they are not defined here.
def compute_predictions(
    model: PredictionModel, dataset: tf.data.Dataset,
    strategy: tf.distribute.Strategy, batch_size: int
) -> Iterator[Tuple[types.ModelPredictions, types.Features]]:
  """Yield the predictions of the model on the given dataset.

  Args:
    model: A function that takes tensor-valued features and returns a vector of
      predictions.
    dataset: The dataset that the function consumes to produce the predictions.
    strategy: The distribution strategy to use when computing.
    batch_size: The batch size to use when batching the dataset.

  Yields:
    Pairs of model predictions and the corresponding metadata.
  """
  with strategy.scope():
    dataset = dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    dataset = dataset.with_options(options)

  for features in strategy.experimental_distribute_dataset(dataset):
    time_start = time.time()
    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
      # TODO(josipd): Figure this out better. We can't easily filter,
      #               as they are PerReplica values, not tensors.
      features_model = {"image": features["image"]}
    else:
      features_model = features
    predictions = materialize(strategy,
                              strategy.run(model, args=(features_model,)))
    time_end = time.time()
    time_delta_per_example = (time_end - time_start) / predictions.shape[0]
    metadatas = materialize(strategy, features["metadata"])
    for i in range(predictions.shape[0]):
      model_predictions = types.ModelPredictions(
          predictions=[predictions[i]],
          time_in_s=time_delta_per_example)
      metadata_i = _slice_dictionary(metadatas, i)
      yield model_predictions, metadata_i
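
Both versions of compute_predictions lean on a materialize helper that is not
shown here. A plausible sketch, assuming it gathers the per-replica pieces of
a distributed value and concatenates them into one NumPy array (dict-valued
inputs such as the metadata would need a corresponding nested treatment):

import numpy as np

# Hypothetical sketch of the undefined `materialize` helper.
def materialize(strategy: tf.distribute.Strategy, value):
    # `experimental_local_results` unpacks a PerReplica value into a tuple of
    # per-replica tensors (a plain tensor yields a one-element tuple).
    return np.concatenate(
        [x.numpy() for x in strategy.experimental_local_results(value)],
        axis=0)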
Example #3
def compute_predictions(
        model: PredictionModel, dataset: tf.data.Dataset,
        strategy: tf.distribute.Strategy) -> Iterator[types.ModelPredictions]:
    """Yield the predictions of the model on the given dataset.

    Note that the dataset is expected to yield batches of tensors.

    Args:
      model: A function that takes tensor-valued features and returns a vector
        of predictions.
      dataset: The dataset that the function consumes to produce the
        predictions.
      strategy: The distribution strategy to use when computing.

    Yields:
      The predictions of the model on the dataset.
    """

    for features in strategy.experimental_distribute_dataset(dataset):
        # TODO(josipd): Figure out how to pass only tpu-allowed types.
        time_start = time.time()
        predictions = materialize(
            strategy,
            strategy.run(model, args=({"image": features["image"]},)))
        time_end = time.time()
        time_delta_per_example = (time_end - time_start) / predictions.shape[0]
        try:
            element_ids = materialize(strategy, features["element_id"])
        except KeyError:
            element_ids = [None] * predictions.shape[0]
        metadatas = materialize(strategy, features["metadata"])
        for i in range(predictions.shape[0]):
            yield types.ModelPredictions(element_id=element_ids[i],
                                         metadata=_slice_dictionary(
                                             metadatas, i),
                                         predictions=[predictions[i]],
                                         time_in_s=time_delta_per_example)
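
The per-example loop in both variants also depends on _slice_dictionary,
which is likewise not shown. A minimal sketch, assuming the materialized
metadata is a flat dictionary of equal-length arrays:

# Hypothetical sketch of the undefined `_slice_dictionary` helper: pick the
# i-th entry out of every array in the metadata dictionary.
def _slice_dictionary(tensor_dict, i: int):
    return {key: value[i] for key, value in tensor_dict.items()}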
Example #4
def predict_image(strategy: tf.distribute.Strategy, gen: Model,
                  input_z: tf.Tensor):
    # `input_z` is captured by the lambda rather than passed through `args`,
    # so every replica evaluates the generator on the same latent batch.
    return strategy.run(lambda: gen(input_z))
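
A possible call site for the snippet above; the latent dimensionality and the
generator itself are assumptions carried over from Example #1:

strategy = tf.distribute.MirroredStrategy()
z_size = 100  # assumed latent dimensionality
input_z = tf.random.normal((16, z_size))
# `gen` is the trained generator (assumed in scope, e.g. from Example #1).
per_replica_images = predict_image(strategy, gen, input_z)
# Concatenate the per-replica outputs back into one batch of images.
images = tf.concat(strategy.experimental_local_results(per_replica_images),
                   axis=0)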