Example #1
 def bottom(self, inputs):
     with tf.variable_scope(self.name):
         common_layers.summarize_video(inputs, "targets_bottom")
         # Embed bitwise.
         assert self.top_dimensionality == 256
         embedded = discretization.int_to_bit_embed(
             inputs, 8, self.PIXEL_EMBEDDING_SIZE)
         # Transpose and project.
         transposed = common_layers.time_to_channels(embedded)
         return tf.layers.dense(transposed,
                                self._body_input_depth,
                                name="merge_pixel_embedded_frames")
Example #2
 def targets_bottom(self, x):  # pylint: disable=arguments-differ
     inputs = x
     with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
         common_layers.summarize_video(inputs, "targets_bottom")
         # Embed bitwise.
         assert self.top_dimensionality == 256
         embedded = discretization.int_to_bit_embed(
             inputs, 8, self.PIXEL_EMBEDDING_SIZE)
         # Transpose and project.
         transposed = common_layers.time_to_channels(embedded)
         return tf.layers.dense(transposed,
                                self._model_hparams.hidden_size,
                                name="merge_pixel_embedded_frames")
Example #3
 def targets_bottom(self, x):  # pylint: disable=arguments-differ
   inputs = x
   with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
     common_layers.summarize_video(inputs, "targets_bottom")
     # Embed bitwise.
     assert self.top_dimensionality == 256
     embedded = discretization.int_to_bit_embed(inputs, 8,
                                                self.PIXEL_EMBEDDING_SIZE)
     # Transpose and project.
     transposed = common_layers.time_to_channels(embedded)
     return tf.layers.dense(
         transposed,
         self._body_input_depth,
         name="merge_pixel_embedded_frames")
Example #4
 def targets_bottom(self, x, summary_prefix="targets_bottom"):  # pylint: disable=arguments-differ
   inputs = x
   with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
     common_layers.summarize_video(inputs, summary_prefix)
     inputs_shape = common_layers.shape_list(inputs)
      # We embed each of the self.top_dimensionality (256) possible pixel values.
     embedding_var = tf.get_variable(
         "pixel_embedding",
         [self.top_dimensionality, self.PIXEL_EMBEDDING_SIZE])
     hot_inputs = tf.one_hot(tf.to_int32(inputs), self.top_dimensionality)
     hot_inputs = tf.reshape(hot_inputs, [-1, self.top_dimensionality])
     embedded = tf.matmul(hot_inputs, embedding_var)
     # Let's now merge all channels that were embedded into a single vector.
     merged_size = self.PIXEL_EMBEDDING_SIZE * inputs_shape[4]
     embedded = tf.reshape(embedded, inputs_shape[:4] + [merged_size])
     transposed = common_layers.time_to_channels(embedded)
     return tf.layers.dense(transposed, self._body_input_depth,
                            name="merge_pixel_embedded_frames")
Example #5
 def targets_bottom(self, x, summary_prefix="targets_bottom"):  # pylint: disable=arguments-differ
   inputs = x
   with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
     common_layers.summarize_video(inputs, summary_prefix)
     inputs_shape = common_layers.shape_list(inputs)
      # We embed each of the self.top_dimensionality (256) possible pixel values.
     embedding_var = tf.get_variable(
         "pixel_embedding",
         [self.top_dimensionality, self.PIXEL_EMBEDDING_SIZE])
     hot_inputs = tf.one_hot(tf.to_int32(inputs), self.top_dimensionality)
     hot_inputs = tf.reshape(hot_inputs, [-1, self.top_dimensionality])
     embedded = tf.matmul(hot_inputs, embedding_var)
     # Let's now merge all channels that were embedded into a single vector.
     merged_size = self.PIXEL_EMBEDDING_SIZE * inputs_shape[4]
     embedded = tf.reshape(embedded, inputs_shape[:4] + [merged_size])
     transposed = common_layers.time_to_channels(embedded)
     return tf.layers.dense(
         transposed,
         self._body_input_depth,
         name="merge_pixel_embedded_frames")
Example #6
 def bottom(self, x):
   inputs = x
   with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
     common_layers.summarize_video(inputs, "inputs")
     inputs = common_layers.standardize_images(inputs)
     return common_layers.time_to_channels(inputs)
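Every variant ends by folding the time axis into channels. A sketch of the reshape that common_layers.time_to_channels presumably applies to a [batch, time, height, width, channels] tensor (the exact axis convention is an assumption):

  import numpy as np

  def time_to_channels_sketch(video):
    # [batch, time, height, width, channels] ->
    # [batch, height, width, time * channels]
    b, t, h, w, c = video.shape
    return video.transpose(0, 2, 3, 1, 4).reshape(b, h, w, t * c)

  video = np.zeros((2, 4, 8, 8, 3))
  print(time_to_channels_sketch(video).shape)  # (2, 8, 8, 12)

This turns a video into an image-shaped tensor so that image-style layers can consume it.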
Example #7
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # Handle videos by folding the time axis into channels.
      if len(labels.shape) == 5:
        labels = common_layers.time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Using it as a loss helps
        # learn the bottleneck function, but back-propagating it into x
        # would let the model minimize it trivially by driving both x and
        # b to zero, so we keep its influence small and stop the gradient
        # on x so x is not zeroed out.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
        # To prevent this loss from exploding we clip it, with a bound
        # that anneals down to 1.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
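        # With probability nomix_p (annealed toward 1), use the
        # unbottlenecked value b elementwise in place of the raw
        # encoder output x.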
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([g, x], axis=0)
        encoder_layers = [tf.concat([l, l], axis=0) for l in encoder_layers]
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      res_gan, res = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
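The reverse_gradient call above applies the standard gradient-reversal trick for adversarial training: the forward value passes through unchanged while the backward gradient is flipped and scaled by lr. A minimal sketch of one common way to write it in TF1-style code (the helper used in the example may be implemented differently):

  import tensorflow as tf

  def reverse_gradient_sketch(x, lr=1.0):
    # Forward value: -lr*x + (1 + lr)*x == x. Only the -lr*x term is
    # differentiated, so upstream gradients are scaled by -lr.
    return -lr * x + tf.stop_gradient((1.0 + lr) * x)

This pairing of a reversed gradient on gan_codes with the negated losses["gan_loss"] is what lets a single optimizer push the discriminator and the generator in opposite directions.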