Example #1
 def bottom(self, x):
     with tf.variable_scope(self.name):
         if not tf.contrib.eager.in_eager_mode():
             tf.summary.image("inputs",
                              common_layers.tpu_safe_image_summary(x),
                              max_outputs=2)
         return tf.to_float(x)
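All of the examples on this page route image tensors through common_layers.tpu_safe_image_summary before calling tf.summary.image, since plain image summaries are problematic inside TPU/XLA-compiled graphs. The tensor2tensor implementation is not reproduced here; the NumPy stand-in below only sketches the assumed contract (pass images through as uint8 off TPU, return a tiny dummy tensor on TPU). mock_tpu_safe_image_summary and the on_tpu flag are hypothetical names used for illustration.

  import numpy as np

  def mock_tpu_safe_image_summary(image, on_tpu=False):
    # Hypothetical stand-in, not the tensor2tensor implementation:
    # on TPU the summary is effectively skipped via a dummy image,
    # otherwise values are clipped and cast so tf.summary.image can render them.
    if on_tpu:
      return np.zeros((1, 1, 1, 1), dtype=np.uint8)
    return np.clip(image, 0, 255).astype(np.uint8)

  batch = np.random.randint(0, 256, size=(2, 32, 32, 3))
  print(mock_tpu_safe_image_summary(batch).shape)               # (2, 32, 32, 3)
  print(mock_tpu_safe_image_summary(batch, on_tpu=True).shape)  # (1, 1, 1, 1)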
Example #2
 def targets_bottom(self, x):
   inputs = x
   with tf.variable_scope(self.name):
     if not tf.contrib.eager.in_eager_mode():
       tf.summary.image(
           "targets_bottom",
           common_layers.tpu_safe_image_summary(inputs),
           max_outputs=1)
     inputs_shape = common_layers.shape_list(inputs)
     if len(inputs_shape) != 4:
       raise ValueError("Assuming images given as int tensors in the format "
                        "[batch, height, width, channels] (256 values).")
     # We embed each of 256=self.top_dimensionality possible pixel values.
     embedding_var = tf.get_variable(
         "pixel_embedding",
         [self.top_dimensionality, self.PIXEL_EMBEDDING_SIZE])
     hot_inputs = tf.one_hot(tf.to_int32(inputs), self.top_dimensionality)
     hot_inputs = tf.reshape(hot_inputs, [-1, self.top_dimensionality])
     embedded = tf.matmul(hot_inputs, embedding_var)
     # Let's now merge all channels that were embedded into a single vector.
     merged_size = self.PIXEL_EMBEDDING_SIZE * inputs_shape[3]
     embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
     merged = tf.layers.dense(
         embedded,
         self._body_input_depth,
         name="merge_pixel_embedded_channels")
     return merged
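As a shape walk-through of the embedding above (sizes are hypothetical; the snippet fixes top_dimensionality at 256 and PIXEL_EMBEDDING_SIZE comes from the modality): the one_hot followed by matmul is an embedding lookup per channel value, and the final reshape concatenates the per-channel embeddings of each pixel before the dense merge.

  import numpy as np

  batch, height, width, channels = 2, 4, 4, 3
  top_dimensionality = 256      # number of distinct pixel values
  PIXEL_EMBEDDING_SIZE = 64     # per-channel embedding width (assumed)

  inputs = np.random.randint(0, 256, size=(batch, height, width, channels))
  embedding_var = np.random.randn(top_dimensionality, PIXEL_EMBEDDING_SIZE)

  hot = np.eye(top_dimensionality)[inputs]            # [b, h, w, c, 256]
  hot = hot.reshape(-1, top_dimensionality)           # [b*h*w*c, 256]
  embedded = hot @ embedding_var                      # [b*h*w*c, 64]

  merged_size = PIXEL_EMBEDDING_SIZE * channels
  embedded = embedded.reshape(batch, height, width, merged_size)
  print(embedded.shape)  # (2, 4, 4, 192); a dense layer then maps this to body_input_depth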
Example #3
 def targets_bottom(self, x):
     inputs = x
     with tf.variable_scope(self.name):
         if not tf.contrib.eager.in_eager_mode():
             tf.summary.image("targets_bottom",
                              common_layers.tpu_safe_image_summary(inputs),
                              max_outputs=1)
         inputs_shape = common_layers.shape_list(inputs)
         if len(inputs_shape) != 4:
             raise ValueError(
                 "Assuming images given as int tensors in the format "
                 "[batch, height, width, channels] (256 values).")
         # We embed each of 256=self.top_dimensionality possible pixel values.
         embedding_var = tf.get_variable(
             "pixel_embedding",
             [self.top_dimensionality, self.PIXEL_EMBEDDING_SIZE])
         hot_inputs = tf.one_hot(tf.to_int32(inputs),
                                 self.top_dimensionality)
         hot_inputs = tf.reshape(hot_inputs, [-1, self.top_dimensionality])
         embedded = tf.matmul(hot_inputs, embedding_var)
         # Let's now merge all channels that were embedded into a single vector.
         merged_size = self.PIXEL_EMBEDDING_SIZE * inputs_shape[3]
         embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
         merged = tf.layers.dense(embedded,
                                  self._body_input_depth,
                                  name="merge_pixel_embedded_channels")
         return merged
Example #4
 def bottom(self, x):
     with tf.variable_scope(self.name):
         if not tf.executing_eagerly():
             tf.summary.image("inputs",
                              common_layers.tpu_safe_image_summary(x),
                              max_outputs=2)
         return tf.to_float(x)
Example #5
 def image_summary(self, name, image_logits, max_outputs=1):
   """Helper for image summaries that are safe on TPU."""
   if len(image_logits.get_shape()) != 5:
     tf.logging.info("Not generating image summary, maybe not an image.")
     return
   return tf.summary.image(
       name,
       common_layers.tpu_safe_image_summary(tf.argmax(image_logits, -1)),
       max_outputs=max_outputs)
Example #6
 def image_summary(self, name, image_logits, max_outputs=1):
   """Helper for image summaries that are safe on TPU."""
   if len(image_logits.get_shape()) != 5:
     tf.logging.info("Not generating image summary, maybe not an image.")
     return
   return tf.summary.image(
       name,
       common_layers.tpu_safe_image_summary(tf.argmax(image_logits, -1)),
       max_outputs=max_outputs)
Example #7
  def top(self, body_output, _):
   num_channels = self._model_hparams.problem.num_channels
   num_frames = self._model_hparams.video_num_target_frames
   with tf.variable_scope("rgb"):
     body_output_shape = common_layers.shape_list(body_output)
     res = tf.layers.dense(body_output, num_channels * num_frames, name="cast")
     res = tf.reshape(res, body_output_shape[:3] + [num_channels, num_frames])
     res = tf.transpose(res, [0, 4, 1, 2, 3])  # Move frames next to batch.
     if not tf.get_variable_scope().reuse:
       res_argmax = res[:, -1, :, :, :]
       tf.summary.image(
           "result",
           common_layers.tpu_safe_image_summary(res_argmax),
           max_outputs=1)
     return tf.expand_dims(res, axis=-1)  # Add an axis like in perplexity.
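A quick shape check for this top (sizes are hypothetical): the dense layer emits one value per channel per target frame, the reshape splits them out, and the transpose moves frames next to the batch axis before the last frame is summarized.

  import numpy as np

  batch, height, width = 2, 8, 8
  num_channels, num_frames = 3, 4

  res = np.random.randn(batch, height, width, num_channels * num_frames)  # after the dense "cast"
  res = res.reshape(batch, height, width, num_channels, num_frames)
  res = res.transpose(0, 4, 1, 2, 3)       # frames next to batch
  print(res.shape)                         # (2, 4, 8, 8, 3)
  print(res[:, -1, :, :, :].shape)         # (2, 8, 8, 3), last frame for the image summary
  print(res[..., np.newaxis].shape)        # (2, 4, 8, 8, 3, 1), the extra axis added on return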
Example #8
 def top(self, body_output, _):
   num_channels = self._model_hparams.problem.num_channels
   num_frames = self._model_hparams.video_num_target_frames
   with tf.variable_scope("rgb"):
     body_output_shape = common_layers.shape_list(body_output)
     res = tf.layers.dense(body_output, num_channels * num_frames, name="cast")
     res = tf.reshape(res, body_output_shape[:3] + [num_channels, num_frames])
     res = tf.transpose(res, [0, 4, 1, 2, 3])  # Move frames next to batch.
     if not tf.get_variable_scope().reuse:
       res_argmax = res[:, -1, :, :, :]
       tf.summary.image(
           "result",
           common_layers.tpu_safe_image_summary(res_argmax),
           max_outputs=1)
     return tf.expand_dims(res, axis=-1)  # Add an axis like in perplexity.
Example #9
  def top(self, body_output, _):
   # TODO(lukaszkaiser): is this a universal enough way to get channels?
   num_channels = self._model_hparams.problem.num_channels
   with tf.variable_scope("rgb_softmax"):
     body_output_shape = common_layers.shape_list(body_output)
     reshape_shape = body_output_shape[:3]
     reshape_shape.extend([num_channels, self.top_dimensionality])
     res = tf.layers.dense(body_output, self.top_dimensionality * num_channels)
     res = tf.reshape(res, reshape_shape)
     if not tf.get_variable_scope().reuse:
       res_argmax = tf.argmax(res, axis=-1)
       tf.summary.image(
           "result",
           common_layers.tpu_safe_image_summary(res_argmax),
           max_outputs=1)
     return res
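The softmax variant keeps a full distribution over the 256 possible values per channel, and the summary simply takes the argmax. A shape sketch with hypothetical sizes:

  import numpy as np

  batch, height, width = 2, 8, 8
  num_channels, top_dimensionality = 3, 256

  res = np.random.randn(batch, height, width, top_dimensionality * num_channels)  # dense output
  res = res.reshape(batch, height, width, num_channels, top_dimensionality)
  print(res.shape)               # (2, 8, 8, 3, 256), per-channel logits
  print(res.argmax(-1).shape)    # (2, 8, 8, 3), the image fed to the summary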
Example #10
 def top(self, body_output, _):
   # TODO(lukaszkaiser): is this a universal enough way to get channels?
   num_channels = self._model_hparams.problem.num_channels
   with tf.variable_scope("rgb_softmax"):
     body_output_shape = common_layers.shape_list(body_output)
     reshape_shape = body_output_shape[:3]
     reshape_shape.extend([num_channels, self.top_dimensionality])
     res = tf.layers.dense(body_output, self.top_dimensionality * num_channels)
     res = tf.reshape(res, reshape_shape)
     if not tf.get_variable_scope().reuse:
       res_argmax = tf.argmax(res, axis=-1)
       tf.summary.image(
           "result",
           common_layers.tpu_safe_image_summary(res_argmax),
           max_outputs=1)
     return res
Example #11
 def top(self, body_output, _):
   num_channels = self._model_hparams.problem.num_channels
   num_frames = self._model_hparams.video_num_target_frames
   with tf.variable_scope("rgb_softmax"):
     body_output_shape = common_layers.shape_list(body_output)
     reshape_shape = body_output_shape[:3]
     reshape_shape.extend([num_channels, num_frames, self.top_dimensionality])
     res = tf.layers.dense(body_output,
                           self.top_dimensionality * num_channels * num_frames)
     res = tf.reshape(res, reshape_shape)
     res = tf.transpose(res, [0, 4, 1, 2, 3, 5])
     if not tf.get_variable_scope().reuse:
       res_argmax = tf.argmax(res[:, -1, :, :, :, :], axis=-1)
       tf.summary.image(
           "result",
           common_layers.tpu_safe_image_summary(res_argmax),
           max_outputs=1)
     return res
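Same idea with the frame dimension added; after the transpose the logits are [batch, frames, height, width, channels, 256], and only the last frame is summarized (hypothetical sizes):

  import numpy as np

  batch, height, width = 2, 8, 8
  num_channels, num_frames, top_dimensionality = 3, 4, 256

  res = np.random.randn(batch, height, width,
                        top_dimensionality * num_channels * num_frames)
  res = res.reshape(batch, height, width, num_channels, num_frames, top_dimensionality)
  res = res.transpose(0, 4, 1, 2, 3, 5)
  print(res.shape)                    # (2, 4, 8, 8, 3, 256)
  print(res[:, -1].argmax(-1).shape)  # (2, 8, 8, 3), last frame for the summary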
Example #12
 def top(self, body_output, _):
   num_channels = self._model_hparams.problem.num_channels
   num_frames = self._model_hparams.video_num_target_frames
   with tf.variable_scope("rgb_softmax"):
     body_output_shape = common_layers.shape_list(body_output)
     reshape_shape = body_output_shape[:3]
     reshape_shape.extend([num_channels, num_frames, self.top_dimensionality])
     res = tf.layers.dense(body_output,
                           self.top_dimensionality * num_channels * num_frames)
     res = tf.reshape(res, reshape_shape)
     res = tf.transpose(res, [0, 4, 1, 2, 3, 5])
     if not tf.get_variable_scope().reuse:
       res_argmax = tf.argmax(res[:, -1, :, :, :, :], axis=-1)
       tf.summary.image(
           "result",
           common_layers.tpu_safe_image_summary(res_argmax),
           max_outputs=1)
     return res
Example #13
    def bottom_compress(self, inputs, name="bottom"):
        """Compresses channel-wise input pixels into whole pixel representions.

    Perform conversion of RGB pixel values to a real number in the range -1 to
    1. This combines pixel channels to form a representation of shape
    [img_len, img_len].

    Args:
      inputs: Tensor representing RGB pixel intensities as integers, of shape
        [batch, img_len, img_len, channels].
      name: string, scope.

    Returns:
      body_input: Tensor of shape [batch, img_len, img_len, body_input_depth].
    """
        with tf.variable_scope(name):
            inputs = tf.to_float(inputs)
            hp = self._model_hparams
            if hp.mode != tf.estimator.ModeKeys.PREDICT:
                tf.summary.image("inputs",
                                 common_layers.tpu_safe_image_summary(inputs),
                                 max_outputs=2)
            inputs = common_layers.convert_rgb_to_symmetric_real(inputs)

            # Reshape inputs to apply convolutions across [img_len, img_len*channels].
            inputs_shape = common_layers.shape_list(inputs)
            inputs = tf.reshape(
                inputs,
                [-1, inputs_shape[1], inputs_shape[2] * inputs_shape[3], 1])
            # tf.logging.info("input shape" , inputs_shape)
            # Compress RGB intensities for each pixel using a convolution.
            # ValueError: Negative dimension size caused by subtracting 3 from 1 for 'imagetransformer/parallel_0_4/
            # imagetransformer/imagetransformer/image_channel_bottom_identity_modality/output_bottom/conv_input/Conv2D'
            #  (op: 'Conv2D') with input shapes: [?,1,1,1], [1,3,1,256].
            outputs = tf.layers.conv2d(inputs,
                                       self._body_input_depth,
                                       kernel_size=(1, self.num_channels),
                                       padding="VALID",
                                       strides=(1, self.num_channels),
                                       activation=tf.nn.relu,
                                       name="conv_input")
            return outputs
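A NumPy sketch of the channel-compressing convolution above, under assumed sizes. The exact convert_rgb_to_symmetric_real formula (taken here as x / 127.5 - 1) is an assumption, and the conv is replaced by an equivalent strided matmul with bias and ReLU omitted; the point is that a (1, channels) kernel with stride (1, channels) reads each pixel's channel triple exactly once, so the output is one depth vector per pixel.

  import numpy as np

  batch, img_len, channels, depth = 2, 8, 3, 16
  inputs = np.random.randint(0, 256, size=(batch, img_len, img_len, channels)).astype(np.float32)
  inputs = inputs / 127.5 - 1.0                              # assumed [-1, 1] mapping

  x = inputs.reshape(batch, img_len, img_len * channels, 1)  # channels folded into width

  kernel = np.random.randn(channels, depth)                  # the (1, channels) conv filter
  out = np.stack([x[:, :, j * channels:(j + 1) * channels, 0] @ kernel
                  for j in range(img_len)], axis=2)
  print(out.shape)  # (2, 8, 8, 16) == [batch, img_len, img_len, body_input_depth]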
Example #14
  def bottom_compress(self, inputs, name="bottom"):
    """Compresses channel-wise input pixels into whole pixel representions.

    Perform conversion of RGB pixel values to a real number in the range -1 to
    1. This combines pixel channels to form a representation of shape
    [img_len, img_len].

    Args:
      inputs: Tensor representing RGB pixel intensities as integers, of shape
        [batch, img_len, img_len, channels].
      name: string, scope.

    Returns:
      body_input: Tensor of shape
        [batch, img_len, img_len, self._model_hparams.hidden_size].
    """
    with tf.variable_scope(name):
      inputs = tf.to_float(inputs)
      hp = self._model_hparams
      if hp.mode != tf.estimator.ModeKeys.PREDICT:
        tf.summary.image(
            "inputs",
            common_layers.tpu_safe_image_summary(inputs),
            max_outputs=2)
      inputs = common_layers.convert_rgb_to_symmetric_real(inputs)

      # Reshape inputs to apply convolutions across [img_len, img_len*channels].
      inputs_shape = common_layers.shape_list(inputs)
      inputs = tf.reshape(
          inputs, [-1, inputs_shape[1], inputs_shape[2] * inputs_shape[3], 1])

      # Compress RGB intensities for each pixel using a convolution.
      outputs = tf.layers.conv2d(
          inputs,
          self._model_hparams.hidden_size,
          kernel_size=(1, self.num_channels),
          padding="VALID",
          strides=(1, self.num_channels),
          activation=tf.nn.relu,
          name="conv_input")
      return outputs
Example #15
    def bottom_compress(self, inputs, name="bottom"):
        """Transform input from data space to model space.

    Perform conversion of RGB pixel values to a real number in the range -1 to 1
    and combine channel values for each pixel to form a representation of
    size image_length x image_length dims.

    Args:
      inputs: A Tensor representing RGB pixel intensities as integers.
        [batch, ...]
      name: string, scope.
    Returns:
      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
    """
        with tf.variable_scope(name):
            inputs = tf.to_float(inputs)
            hp = self._model_hparams
            if hp.mode != tf.estimator.ModeKeys.PREDICT:
                tf.summary.image("inputs",
                                 common_layers.tpu_safe_image_summary(inputs),
                                 max_outputs=2)
            inputs = common_layers.convert_rgb_to_symmetric_real(inputs)
            ishape = common_layers.shape_list(inputs)
            inputs = tf.reshape(inputs,
                                [-1, ishape[1], ishape[2] * ishape[3], 1])
            inputs.set_shape([None, None, None, 1])
            # We compress RGB intensities for each pixel using a conv.
            x = tf.layers.conv2d(inputs,
                                 self._body_input_depth,
                                 (1, self.num_channels),
                                 padding="VALID",
                                 strides=(1, self.num_channels),
                                 activation=tf.nn.relu,
                                 name="conv_input")
            x.set_shape([None, None, None, self._body_input_depth])
            return x
Example #16
  def bottom_compress(self, inputs, name="bottom"):
    """Transform input from data space to model space.

    Perform conversion of RGB pixel values to a real number in the range -1 to 1
    and combine channel values for each pixel to form a representation of
    size image_length x image_length dims.

    Args:
      inputs: A Tensor representing RGB pixel intensities as integers.
        [batch, ...]
      name: string, scope.
    Returns:
      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
    """
    with tf.variable_scope(name):
      inputs = tf.to_float(inputs)
      hp = self._model_hparams
      if hp.mode != tf.estimator.ModeKeys.PREDICT:
        tf.summary.image(
            "inputs",
            common_layers.tpu_safe_image_summary(inputs),
            max_outputs=2)
      inputs = common_layers.convert_rgb_to_symmetric_real(inputs)
      ishape = common_layers.shape_list(inputs)
      inputs = tf.reshape(inputs, [-1, ishape[1], ishape[2] * ishape[3], 1])
      inputs.set_shape([None, None, None, 1])
      # We compress RGB intensities for each pixel using a conv.
      x = tf.layers.conv2d(
          inputs,
          self._body_input_depth, (1, self.num_channels),
          padding="VALID",
          strides=(1, self.num_channels),
          activation=tf.nn.relu,
          name="conv_input")
      x.set_shape([None, None, None, self._body_input_depth])
      return x
Example #17
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      x = features["targets"]
      shape = common_layers.shape_list(x)
      is1d = shape[2] == 1
      self.is1d = is1d
      # Run encoder.
      x = self.encoder(x)
      # Bottleneck (mix during early training, not too important but stable).
      b, b_loss = self.bottleneck(x)
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(), common_layers.shape_list(x)[-1], reuse=True)
        b = tf.concat([g, b], axis=0)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      if hparams.bottleneck_max_prob < -1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x, {"bottleneck_loss": 0.0}
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]
    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      # Split back if we added a purely sampled batch.
      res_gan, res = tf.split(res, 2, axis=0)
      num_channels = self.hparams.problem.num_channels
      res_rgb = common_layers.convert_real_to_rgb(
          tf.nn.sigmoid(tf.layers.dense(res_gan, num_channels, name="gan_rgb")))
      tf.summary.image(
          "gan", common_layers.tpu_safe_image_summary(res_rgb), max_outputs=1)
      orig_rgb = tf.to_float(features["targets_raw"])

      def discriminate(x):
        return self.discriminator(x, is_training=is_training)

      gan_loss = common_layers.sliced_gan_loss(orig_rgb,
                                               reverse_gradient(res_rgb),
                                               discriminate,
                                               self.hparams.num_sliced_vecs)
      gan_loss *= hparams.gan_loss_factor
    # Mix the final result and return.
    res = common_layers.mix(res, features["targets"],
                            hparams.bottleneck_warmup_steps // 2, is_training)
    return res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss}
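One detail worth tracing in this body: when gan_loss_factor is non-zero, a purely sampled batch g is concatenated in front of the real bottleneck batch before decoding, and tf.split(res, 2, axis=0) later separates the decoded GAN half from the reconstruction half. A minimal NumPy sketch of that bookkeeping (hypothetical sizes):

  import numpy as np

  batch, h, w, depth = 4, 8, 8, 16
  b = np.random.randn(batch, h, w, depth)   # unbottlenecked real batch
  g = np.random.randn(batch, h, w, depth)   # unbottlenecked sampled batch

  x = np.concatenate([g, b], axis=0)        # decoder sees 2 * batch examples
  res = x                                   # stand-in for decoding + cropping
  res_gan, res = np.split(res, 2, axis=0)   # GAN half first, reconstruction half second
  print(res_gan.shape, res.shape)           # (4, 8, 8, 16) (4, 8, 8, 16)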
Example #18
 def bottom(self, x):
   with tf.variable_scope(self.name):
     if not tf.contrib.eager.in_eager_mode():
       tf.summary.image(
           "inputs", common_layers.tpu_safe_image_summary(x), max_outputs=2)
     return tf.to_float(x)
Example #19
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      x = features["targets"]
      labels = features["targets_raw"]
      shape = common_layers.shape_list(x)
      is1d = shape[2] == 1
      self.is1d = is1d
      # Run encoder.
      x = self.encoder(x)
      # Bottleneck (mix during early training, not too important but stable).
      b, b_loss = self.bottleneck(x)
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(), common_layers.shape_list(x)[-1], reuse=True)
        b = tf.concat([g, b], axis=0)
      # With probability bottleneck_max_prob use the bottleneck, otherwise x.
      if hparams.bottleneck_max_prob < -1.0:
        x = tf.where(
            tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
      else:
        x = b
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      return x, {"bottleneck_loss": 0.0}
    # Cut to the right size and mix before returning.
    res = x[:, :shape[1], :shape[2], :]

    is_image = isinstance(self.hparams.problem, image_utils.ImageProblem)
    if is_image:
      vocab_size = self.hparams.problem.vocab_size

      res = tf.layers.dense(
          res, self.hparams.problem.num_channels * self.hparams.hidden_size)
      output_shape = common_layers.shape_list(res)[:-1] + [
          self.hparams.problem.num_channels, self.hparams.hidden_size
      ]
      res = tf.reshape(res, output_shape)
    elif isinstance(self.hparams.problem, text_problems.Text2TextProblem):
      vocab_size = self._problem_hparams.target_modality.top_dimensionality
      res = tf.layers.dense(res, self.hparams.hidden_size)
    else:
      raise Exception("Unsupported problem type: %s" % self.hparams.problem)

    one_hot_labels = tf.one_hot(labels, vocab_size)
    code_loss_gan = 0.0
    if hparams.gan_loss_factor != 0.0:
      res_gan, res = tf.split(res, 2, axis=0)
      with tf.variable_scope("vq"):
        reconstr_gan, _, code_loss_gan, _ = discretization.vq_loss(
            res, one_hot_labels, vocab_size)

    with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
      reconstr, target_codes, code_loss, targets_loss = discretization.vq_loss(
          res, one_hot_labels, vocab_size)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      if is_image:
        tf.summary.image(
            "gan",
            common_layers.tpu_safe_image_summary(tf.argmax(reconstr_gan, -1)),
            max_outputs=1)

      def discriminate(x):
        return self.discriminator(x, is_training=is_training)

      gan_loss = common_layers.sliced_gan_loss(target_codes,
                                               reverse_gradient(res_gan),
                                               discriminate,
                                               self.hparams.num_sliced_vecs)
      gan_loss *= hparams.gan_loss_factor

    if is_image:
      tf.summary.image(
          "ae",
          common_layers.tpu_safe_image_summary(tf.argmax(reconstr, -1)),
          max_outputs=1)

    return reconstr, {
        "training": targets_loss,
        "code_loss": code_loss,
        "code_loss_gan": code_loss_gan,
        "b_loss": b_loss,
        "gan_loss": -gan_loss
    }
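For the image branch in this variant, the decoder output is expanded to one hidden vector per color channel before the VQ loss, and the raw integer targets are one-hot encoded over the vocabulary. A shape sketch with hypothetical sizes:

  import numpy as np

  batch, h, w = 2, 8, 8
  num_channels, hidden_size, vocab_size = 3, 16, 256

  res = np.random.randn(batch, h, w, num_channels * hidden_size)  # after the dense layer
  res = res.reshape(batch, h, w, num_channels, hidden_size)
  print(res.shape)                    # (2, 8, 8, 3, 16)

  labels = np.random.randint(0, vocab_size, size=(batch, h, w, num_channels))
  one_hot_labels = np.eye(vocab_size)[labels]
  print(one_hot_labels.shape)         # (2, 8, 8, 3, 256)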