def loss_iw(self, logits, features):
        if isinstance(logits, dict):
            losses = {}
            for k, v in six.iteritems(logits):
                losses[k] = self._loss_single_iw(v,
                                                 k,
                                                 features[k],
                                                 weights=features.get(k +
                                                                      "_mask"))

                n, d = losses[k]
                if common_layers.should_generate_summaries():
                    tf.summary.scalar(k + "_loss", n / d)
                    tf.summary.scalar(k + "_loss_num", n)
                    tf.summary.scalar(k + "_loss_den", d)
                    if getattr(self.hparams, "visualize_logits_histogram",
                               False):
                        hist = tf.summary.histogram
                        hist(k + "_predict", tf.argmax(tf.squeeze(v), axis=-1))
                        hist(k + "_targets", features[k])

            return tf.add_n([n / d for n, d in losses.values()])
        else:
            return self._loss_single_iw(logits,
                                        "targets",
                                        features["targets"],
                                        weights=features.get("targets_mask"))
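
Each per-feature loss above is returned as a (numerator, denominator) pair, and the total loss is the sum of the ratios. A minimal plain-Python sketch of that aggregation pattern (names are illustrative):

def aggregate_losses(losses):
    """Sum of numerator/denominator ratios, as in the dict branch above."""
    return sum(n / d for n, d in losses.values())

# Two hypothetical output heads with (weighted loss sum, weight sum) pairs.
losses = {"targets": (12.0, 4.0), "targets_extra": (3.0, 2.0)}
print(aggregate_losses(losses))  # 3.0 + 1.5 = 4.5
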
Example #2
def optimize(loss, learning_rate, hparams, use_tpu=False):
    """Minimize loss."""
    loss = weight_decay_and_noise(loss, hparams, learning_rate)
    loss = tf.identity(loss, name="total_loss")
    log_variable_sizes(verbose=hparams.summarize_vars)
    if hparams.summarize_vars:
        summarize_variables()
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    log_variable_sizes(diet_vars,
                       "Diet Variables",
                       verbose=hparams.summarize_vars)
    opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams,
                               use_tpu)
    if use_tpu:
        opt = tf.contrib.tpu.CrossShardOptimizer(opt)

    opt_summaries = []
    if common_layers.should_generate_summaries():
        tf.summary.scalar("learning_rate", learning_rate)
        opt_summaries = ["loss"]
    if hparams.summarize_grads and common_layers.should_generate_summaries():
        tf.logging.info("Summarizing gradients")
        opt_summaries.extend(
            ["gradients", "gradient_norm", "global_gradient_norm"])

    if hparams.clip_grad_norm:
        tf.logging.info("Clipping gradients, norm: %0.5f",
                        hparams.clip_grad_norm)
    if hparams.grad_noise_scale:
        tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                        hparams.grad_noise_scale)

    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=learning_rate,
        clip_gradients=hparams.clip_grad_norm or None,
        gradient_noise_scale=hparams.grad_noise_scale or None,
        optimizer=opt,
        summaries=opt_summaries,
        colocate_gradients_with_ops=True)
    return train_op
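
Every snippet on this page gates its tf.summary calls on common_layers.should_generate_summaries(). A minimal TF1 sketch of such a check, assuming the tensor2tensor convention of skipping summaries inside tf.while_loop bodies and reused variable scopes (a sketch of the idea, not necessarily the exact library code):

import tensorflow as tf

def should_generate_summaries():
    """Return False where tf.summary ops would fail or be duplicated."""
    name_scope = tf.contrib.framework.get_name_scope()
    if name_scope and "while/" in name_scope:
        return False  # summaries do not work inside tf.while_loop bodies
    if tf.get_variable_scope().reuse:
        return False  # avoid duplicate summaries in reused (replicated) scopes
    return True
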
Example #3
 def decoder(self, x, encoder_layers=None):
   with tf.variable_scope("decoder"):
     hparams = self.hparams
     is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
     kernel, strides = self._get_kernel_and_strides()
     residual_kernel = (hparams.residual_kernel_height,
                        hparams.residual_kernel_width)
     residual_kernel1d = (hparams.residual_kernel_height, 1)
     residual_kernel = residual_kernel1d if self.is1d else residual_kernel
     residual_conv = tf.layers.conv2d
     if hparams.residual_use_separable_conv:
       residual_conv = tf.layers.separable_conv2d
     # Up-convolutions.
     for i in range(hparams.num_hidden_layers):
       j = hparams.num_hidden_layers - i - 1
       if is_training:
         nomix_p = common_layers.inverse_lin_decay(
             int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
         if common_layers.should_generate_summaries():
           tf.summary.scalar("nomix_p_%d" % j, nomix_p)
       filters = hparams.hidden_size * 2**j
       filters = min(filters, hparams.max_hidden_size)
       with tf.variable_scope("layer_%d" % i):
         j = hparams.num_hidden_layers - i - 1
         x = tf.layers.conv2d_transpose(
             x,
             filters,
             kernel,
             strides=strides,
             padding="SAME",
             activation=common_layers.belu,
             name="strided")
         y = x
         for r in range(hparams.num_residual_layers):
           residual_filters = filters
           if r < hparams.num_residual_layers - 1:
             residual_filters = int(
                 filters * hparams.residual_filter_multiplier)
           y = residual_conv(
               y,
               residual_filters,
               residual_kernel,
               padding="SAME",
               activation=common_layers.belu,
               name="residual_%d" % r)
         x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
         x = common_layers.layer_norm(x, name="ln")
         x = common_attention.add_timing_signal_nd(x)
         if encoder_layers is not None:
           enc_x = encoder_layers[j]
           enc_shape = common_layers.shape_list(enc_x)
           x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
           if is_training:  # Mix at the beginning of training.
             rand = tf.random_uniform(common_layers.shape_list(x_mix))
             x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
           x = x_mix
     return x
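
The nomix_p schedule above comes from common_layers.inverse_lin_decay, which ramps linearly up to 1.0 over a given number of steps. A rough TF1 sketch of the assumed behavior (illustrative only):

import tensorflow as tf

def inverse_lin_decay(max_step, min_value=0.01, step=None):
    """Increase linearly from min_value to 1.0, reaching 1.0 at max_step."""
    if step is None:
        step = tf.train.get_or_create_global_step()
    progress = tf.minimum(tf.to_float(step) / float(max_step), 1.0)
    return progress * (1.0 - min_value) + min_value
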
Example #4
def optimize(loss, learning_rate, hparams, use_tpu=False):
  """Minimize loss."""
  loss = weight_decay_and_noise(loss, hparams, learning_rate)
  loss = tf.identity(loss, name="total_loss")
  # Print trainable variables.
  log_variable_sizes(verbose=hparams.summarize_vars)
  # Print non-trainable variables.
  non_trainable_variables = list(
      set(tf.global_variables()) - set(tf.trainable_variables()))
  log_variable_sizes(non_trainable_variables, tag="Non-trainable variables",
                     verbose=hparams.summarize_vars)
  if hparams.summarize_vars:
    summarize_variables()
    # Summarize non-trainable variables as well
    summarize_variables(non_trainable_variables, tag="Non-trainable variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  log_variable_sizes(
      diet_vars, "Diet Variables", verbose=hparams.summarize_vars)
  opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
  if use_tpu:
    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

  opt_summaries = []
  if common_layers.should_generate_summaries():
    tf.summary.scalar("learning_rate", learning_rate)
    opt_summaries.append("loss")
    if hparams.summarize_grads:
      tf.logging.info("Summarizing gradients")
      opt_summaries.extend(
          ["gradients", "gradient_norm", "global_gradient_norm"])

  if hparams.clip_grad_norm:
    tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
  if hparams.grad_noise_scale:
    tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                    hparams.grad_noise_scale)

  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=loss,
      global_step=tf.train.get_or_create_global_step(),
      learning_rate=learning_rate,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True)
  return train_op
Example #5
def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  """Apply weight decay and weight noise."""
  if var_list is None:
    var_list = tf.trainable_variables()

  decay_vars = [v for v in var_list]
  noise_vars = [v for v in var_list if "/body/" in v.name]

  weight_decay_loss = weight_decay(hparams.weight_decay, decay_vars)
  if hparams.weight_decay and common_layers.should_generate_summaries():
    tf.summary.scalar("losses/weight_decay", weight_decay_loss)
  weight_noise_ops = weight_noise(hparams.weight_noise, learning_rate,
                                  noise_vars)

  with tf.control_dependencies(weight_noise_ops):
    loss = tf.identity(loss)

  loss += weight_decay_loss
  return loss
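
The weight_decay() helper itself is not shown above. A plausible TF1 sketch that matches how it is used (a scaled sum of L2 penalties over the decay variables, typically skipping biases); this is an assumption, not necessarily the library's exact code:

import tensorflow as tf

def weight_decay(decay_rate, var_list, skip_biases=True):
    """Scaled sum of L2 penalties over var_list; 0.0 when decay is disabled."""
    if not decay_rate:
        return 0.
    penalties = []
    for v in var_list:
        # Heuristic bias detection: rank-1 variables named "...bias:0".
        is_bias = len(v.shape.as_list()) == 1 and v.name.endswith("bias:0")
        if not (skip_biases and is_bias):
            penalties.append(tf.nn.l2_loss(v))
    if not penalties:
        return 0.
    return tf.add_n(penalties) * decay_rate
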
Example #6
def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  """Apply weight decay and weight noise."""
  if var_list is None:
    var_list = tf.trainable_variables()

  decay_vars = [v for v in var_list]
  noise_vars = [v for v in var_list if "/body/" in v.name]

  weight_decay_loss = weight_decay(hparams.weight_decay, decay_vars)
  if hparams.weight_decay and common_layers.should_generate_summaries():
    tf.summary.scalar("losses/weight_decay", weight_decay_loss)
  weight_noise_ops = weight_noise(hparams.weight_noise, learning_rate,
                                  noise_vars)

  with tf.control_dependencies(weight_noise_ops):
    loss = tf.identity(loss)

  loss += weight_decay_loss
  return loss
Example #7
def weight_noise(noise_rate, learning_rate, var_list):
  """Apply weight noise to vars in var_list."""
  if not noise_rate:
    return [tf.no_op()]

  tf.logging.info("Applying weight noise scaled by learning rate, "
                  "noise_rate: %0.5f", noise_rate)

  noise_ops = []

  for v in var_list:
    with tf.device(v.device):  # pylint: disable=protected-access
      scale = noise_rate * learning_rate * 0.001
      if common_layers.should_generate_summaries():
        tf.summary.scalar("weight_noise_scale", scale)
      noise = tf.truncated_normal(v.shape) * scale
      noise_op = v.assign_add(noise)
      noise_ops.append(noise_op)

  return noise_ops
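
The returned noise ops only take effect when something depends on them, which is why weight_decay_and_noise threads them through tf.control_dependencies. A small hypothetical TF1 usage sketch (assumes the weight_noise above and its common_layers dependency are importable):

import tensorflow as tf

v = tf.get_variable("w", shape=[3, 3], initializer=tf.zeros_initializer())
noise_ops = weight_noise(noise_rate=0.01, learning_rate=0.1, var_list=[v])
with tf.control_dependencies(noise_ops):
    # Ops created here run only after the weights have been perturbed.
    loss = tf.identity(tf.reduce_sum(tf.square(v)))
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss))  # small positive value from the injected noise
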
Example #8
def weight_noise(noise_rate, learning_rate, var_list):
  """Apply weight noise to vars in var_list."""
  if not noise_rate:
    return [tf.no_op()]

  tf.logging.info("Applying weight noise scaled by learning rate, "
                  "noise_rate: %0.5f", noise_rate)

  noise_ops = []

  for v in var_list:
    with tf.device(v.device):  # pylint: disable=protected-access
      scale = noise_rate * learning_rate * 0.001
      if common_layers.should_generate_summaries():
        tf.summary.scalar("weight_noise_scale", scale)
      noise = tf.truncated_normal(v.shape) * scale
      noise_op = v.assign_add(noise)
      noise_ops.append(noise_op)

  return noise_ops
Example #9
def vq_gating(x, num_experts, k, bneck, hparams=None, name="vq_gating"):
    """VQ gating.
  Args:
    x: input Tensor with shape [batch_size, input_size]
    num_experts: an integer
    k: an integer - number of experts per example
    bneck: a bottleneck object
    hparams: optional hparams
    name: an optional string
  Returns:
    gates: a Tensor with shape [batch_size, num_experts]
    extra_loss: a scalar auxiliary loss from the discrete bottleneck
    centroids: the embedded centroids if hparams.residual_centroids, else None
  """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):

        if hparams.use_scales:
            scales = tf.get_variable("scales", [num_experts],
                                     tf.float32,
                                     initializer=tf.ones_initializer())
            scales = tf.nn.softmax(scales)
            hparams.scales = scales
        input_size = x.get_shape().as_list()[-1]
        batch_size = common_layers.shape_list(x)[0]

        if k > 1:
            # first project into two dense layers, chop and discretize, and gate
            # TODO(avaswani): Maybe scale the embeddings flowing out of the experts.
            # We might want to do this to match the computation being done by topk
            x = tf.layers.dense(x, input_size * k)
            # x goes from [batch_size, input_size*k] to [batch_size*k, input_size]
            x = tf.reshape(x, [batch_size * k, input_size])
        inputs = tf.expand_dims(x, axis=1)
        inputs = tf.expand_dims(inputs, axis=1)
        # VQ hparams
        hparams.z_size = int(math.log(num_experts, 2))
        hparams.hidden_size = input_size
        hparams.top_k = k
        d = bneck.discrete_bottleneck(inputs)
        centroids = None
        exp_discrete = d["discrete"]
        embed_lookup = d["embed"]
        extra_loss = d["loss"]
        if hparams.residual_centroids:
            centroids = embed_lookup(exp_discrete)  # gives the centroids
        top_k_indices = tf.squeeze(exp_discrete, axis=1)
        tf.summary.histogram("discrete_counts", top_k_indices)
        # if k > 1, then we need to reshape top_k_indices from [batch_size*k, 1]
        # to [batch_size, k]
        if k > 1:
            top_k_indices = tf.reshape(top_k_indices, [batch_size, k])
        # get the top k gates
        top_k_gates = tf.ones([batch_size, k])
        # This will be a `Tensor` of shape `[batch_size, n]`, with zeros in the
        # positions corresponding to all but the top k experts per example.
        gates = _rowwise_unsorted_segment_sum(top_k_gates, top_k_indices,
                                              num_experts)
        # Compute count per expert from the gates.
        # gates has shape [batch_size, num_experts]
        # count per expert has shape [num_experts, 1]
        count_per_expert = tf.reduce_sum(gates, axis=0)
        if hparams.use_scales:
            scale_loss = tf.reduce_mean(tf.to_float(count_per_expert) * scales)
            extra_loss += scale_loss
        if common_layers.should_generate_summaries():
            tf.summary.histogram("vq_loss", extra_loss)
            if hparams.use_scales:
                tf.summary.histogram("scale_loss", scale_loss)
        return gates, extra_loss, centroids
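
vq_gating relies on _rowwise_unsorted_segment_sum to scatter the per-example gate values into a [batch_size, num_experts] matrix. One way to implement it is to offset each row's indices and run a single unsorted_segment_sum; a sketch of that idea (not necessarily the library's exact code):

import tensorflow as tf

def _rowwise_unsorted_segment_sum(values, indices, n):
    """Row-wise segment sum: values/indices are [batch, k], output is [batch, n]."""
    batch, k = tf.unstack(tf.shape(indices), num=2)
    # Offset each row's indices so all rows share one flat segment space.
    offsets = tf.range(batch * k) // k * n
    indices_flat = tf.reshape(tf.to_int32(indices), [-1]) + offsets
    ret_flat = tf.unsorted_segment_sum(
        tf.reshape(values, [-1]), indices_flat, batch * n)
    return tf.reshape(ret_flat, [batch, n])
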
Example #10
def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
    """Minimize loss."""
    loss = weight_decay_and_noise(loss, hparams, learning_rate)
    if hparams.get('shs_regularization', default=None) is not None:
        if hparams.shs_regularization:
            loss = weight_group_hoyer_square(loss, hparams)
    if hparams.get('ssl_regularization', default=None) is not None:
        if hparams.ssl_regularization:
            loss = weight_group_lasso(loss, hparams)
    loss = tf.identity(loss, name="total_loss")
    if variables is None:
        variables = tf.trainable_variables()
    # Print trainable variables.
    log_variable_sizes(variables, verbose=hparams.summarize_vars)
    # Print non-trainable variables.
    non_trainable_variables = list(set(tf.global_variables()) - set(variables))
    log_variable_sizes(non_trainable_variables,
                       tag="Non-trainable variables",
                       verbose=hparams.summarize_vars)
    if hparams.summarize_vars:
        summarize_variables(variables)
        # Summarize non-trainable variables as well
        summarize_variables(non_trainable_variables,
                            tag="Non-trainable variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    log_variable_sizes(diet_vars,
                       "Diet Variables",
                       verbose=hparams.summarize_vars)
    opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams,
                               use_tpu)
    if use_tpu:
        opt = tf.contrib.tpu.CrossShardOptimizer(opt)
    opt_summaries = []
    if common_layers.should_generate_summaries():
        tf.summary.scalar("learning_rate", learning_rate)
        opt_summaries.append("loss")
        if hparams.summarize_grads:
            tf.logging.info("Summarizing gradients")
            opt_summaries.extend(
                ["gradients", "gradient_norm", "global_gradient_norm"])

    if hparams.clip_grad_norm:
        tf.logging.info("Clipping gradients, norm: %0.5f",
                        hparams.clip_grad_norm)
    if hparams.grad_noise_scale:
        tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                        hparams.grad_noise_scale)

    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=learning_rate,
        clip_gradients=hparams.clip_grad_norm or None,
        gradient_noise_scale=hparams.grad_noise_scale or None,
        optimizer=opt,
        summaries=opt_summaries,
        colocate_gradients_with_ops=True,
        variables=variables)
    return train_op
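
weight_group_hoyer_square and weight_group_lasso are not shown on this page. Purely as an illustration of the kind of structured-sparsity penalty the names suggest, a hypothetical group-lasso regularizer over convolution kernels might look like this (the hparam name and grouping are assumptions):

import tensorflow as tf

def weight_group_lasso(loss, hparams):
    """Hypothetical group lasso: L2 norm per output channel of each conv kernel."""
    scale = getattr(hparams, "ssl_loss_scale", 1e-4)  # assumed hparam name
    penalty = 0.
    for v in tf.trainable_variables():
        if "kernel" in v.name and v.shape.ndims == 4:  # conv kernels [h, w, in, out]
            w = tf.reshape(v, [-1, v.shape.as_list()[-1]])
            penalty += tf.reduce_sum(
                tf.sqrt(tf.reduce_sum(tf.square(w), axis=0) + 1e-8))
    return loss + scale * penalty
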
Example #11
def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
    """Minimize loss."""
    loss = weight_decay_and_noise(loss, hparams, learning_rate)
    loss = tf.identity(loss, name="total_loss")
    if variables is None:
        variables = tf.trainable_variables()
    # Print trainable variables.
    log_variable_sizes(variables, verbose=hparams.summarize_vars)
    # Print non-trainable variables.
    non_trainable_variables = list(set(tf.global_variables()) - set(variables))
    log_variable_sizes(non_trainable_variables,
                       tag="Non-trainable variables",
                       verbose=hparams.summarize_vars)
    if hparams.summarize_vars:
        summarize_variables(variables)
        # Summarize non-trainable variables as well
        summarize_variables(non_trainable_variables,
                            tag="Non-trainable variables")
    diet_vars = [
        v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
    ]
    log_variable_sizes(diet_vars,
                       "Diet Variables",
                       verbose=hparams.summarize_vars)
    opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams,
                               use_tpu)
    if use_tpu:
        opt = tf.contrib.tpu.CrossShardOptimizer(opt)
    if getattr(hparams, "gpu_automatic_mixed_precision", False):
        if use_tpu:
            raise RuntimeError(
                "GPU auto mixed precision cannot be used with TPU")
        elif _mixed_precision_is_enabled(hparams):
            raise RuntimeError(
                "GPU auto mixed precision cannot be used with manual mixed precision"
            )
        else:
            setattr(opt, "_use_locking", "True")
            setattr(opt, "_name", "ConditionalOptimizer")
            opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
                opt)

    opt_summaries = []
    if common_layers.should_generate_summaries():
        tf.summary.scalar("learning_rate", learning_rate)
        opt_summaries.append("loss")
        if hparams.summarize_grads:
            tf.logging.info("Summarizing gradients")
            opt_summaries.extend(
                ["gradients", "gradient_norm", "global_gradient_norm"])

    if hparams.clip_grad_norm:
        tf.logging.info("Clipping gradients, norm: %0.5f",
                        hparams.clip_grad_norm)
    if hparams.grad_noise_scale:
        tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                        hparams.grad_noise_scale)

    train_op = tf.contrib.layers.optimize_loss(
        name="training",
        loss=loss,
        global_step=tf.train.get_or_create_global_step(),
        learning_rate=learning_rate,
        clip_gradients=hparams.clip_grad_norm or None,
        gradient_noise_scale=hparams.grad_noise_scale or None,
        optimizer=opt,
        summaries=opt_summaries,
        colocate_gradients_with_ops=True,
        variables=variables)
    return train_op
Example #12
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
      vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(
            tf.squared_difference(x_stop, b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([x, g], axis=0)
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
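
The GAN branch above reverses gradients on the sampled codes via reverse_gradient before the sliced GAN loss. A minimal sketch of a gradient-reversal helper consistent with that usage (identity forward, gradient scaled by -lr backward; assumed behavior):

import tensorflow as tf

def reverse_gradient(x, lr=1.0):
    """Identity in the forward pass; gradients scaled by -lr in the backward pass."""
    return -lr * x + tf.stop_gradient((1.0 + lr) * x)
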
Example #13
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([x, g], axis=0)
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      res, res_gan = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
Example #14
 def decoder(self, x, encoder_layers=None):
   with tf.variable_scope("decoder"):
     hparams = self.hparams
     is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
     kernel, strides = self._get_kernel_and_strides()
     residual_kernel = (hparams.residual_kernel_height,
                        hparams.residual_kernel_width)
     residual_kernel1d = (hparams.residual_kernel_height, 1)
     residual_kernel = residual_kernel1d if self.is1d else residual_kernel
     residual_conv = tf.layers.conv2d
     if hparams.residual_use_separable_conv:
       residual_conv = tf.layers.separable_conv2d
     # Up-convolutions.
     for i in range(hparams.num_hidden_layers):
       j = hparams.num_hidden_layers - i - 1
       if is_training:
         nomix_p = common_layers.inverse_lin_decay(
             int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
         if common_layers.should_generate_summaries():
           tf.summary.scalar("nomix_p_%d" % j, nomix_p)
       filters = hparams.hidden_size * 2**j
       filters = min(filters, hparams.max_hidden_size)
       with tf.variable_scope("layer_%d" % i):
         j = hparams.num_hidden_layers - i - 1
         x = tf.layers.conv2d_transpose(
             x,
             filters,
             kernel,
             strides=strides,
             padding="SAME",
             activation=common_layers.belu,
             name="strided")
         y = x
         for r in range(hparams.num_residual_layers):
           residual_filters = filters
           if r < hparams.num_residual_layers - 1:
             residual_filters = int(
                 filters * hparams.residual_filter_multiplier)
           y = residual_conv(
               y,
               residual_filters,
               residual_kernel,
               padding="SAME",
               activation=common_layers.belu,
               name="residual_%d" % r)
         x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
         x = common_layers.layer_norm(x, name="ln")
         x = common_attention.add_timing_signal_nd(x)
         if encoder_layers is not None:
           enc_x = encoder_layers[j]
           enc_shape = common_layers.shape_list(enc_x)
           x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
           if is_training:  # Mix at the beginning of training.
             rand = tf.random_uniform(common_layers.shape_list(x_mix))
             x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
           if hparams.gan_loss_factor != 0:
             x_gan = x[enc_shape[0]:, :enc_shape[1], :enc_shape[2], :]
             x = tf.concat([x_mix, x_gan], axis=0)
           else:
             x = x_mix
     return x
Example #15
def autoencoder_body(self, features):
  """ Customized body function for autoencoders acting on continuous images.
  This is based on tensor2tensor.models.research.AutoencoderBasic.body
  and should be compatible with most derived classes.

  The original autoencoder class relies on embedding the channels to a discrete
  vocabulary and defines the loss on that vocab. It's cool and all, but here we
  prefer expressing the reconstruction loss as an actual continuous likelihood
  function.
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  output_activation = tf.nn.softplus if hparams.output_activation == 'softplus' else None
  input_shape = [None, ] + common_layers.shape_list(features["inputs"])[1:]

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    # In predict mode, we also define TensorFlow Hub modules for all pieces of
    # the autoencoder
    if hparams.encode_psf and 'psf' in features:
      psf_shape = [None, ] + common_layers.shape_list(features["psf"])[1:]
    # First build encoder spec
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))
      x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs=input_layer, outputs=b)

    def make_model_spec_psf():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      psf_layer = tf.placeholder(tf.float32, shape=psf_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))

      # If we have access to the PSF, we add this information to the encoder
      if hparams.encode_psf and 'psf' in features:
        psf_image = tf.expand_dims(tf.signal.irfft2d(tf.cast(psf_layer[...,0], tf.complex64)), axis=-1)
        # Roll the image to undo the fftshift, assuming x1 zero padding and x2 subsampling
        psf_image = tf.roll(psf_image, shift=[input_shape[1], input_shape[2]], axis=[1,2])
        psf_image = tf.image.resize_with_crop_or_pad(psf_image, input_shape[1], input_shape[2])
        net_psf = tf.layers.conv2d(psf_image,
                                   hparams.hidden_size // 4, 5,
                                   padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs={'input':input_layer, 'psf':psf_layer}, outputs=b)

    spec = hub.create_module_spec(make_model_spec_psf if hparams.encode_psf else make_model_spec, drop_collections=['checkpoints'])
    encoder = hub.Module(spec, name="encoder_module")
    hub.register_module_for_export(encoder, "encoder")

    if hparams.encode_psf:
      code = encoder({'input':features["inputs"], 'psf': features['psf']})
    else:
      code = encoder(features["inputs"])
    b_shape = [None, ] + common_layers.shape_list(code)[1:]
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)

    # Second build decoder spec
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=b_shape)
      x = self.unbottleneck(input_layer, res_size)
      x = self.decoder(x, None)
      reconstr = tf.layers.dense(x, input_shape[-1], name="autoencoder_final",
                                 activation=output_activation)
      hub.add_signature(inputs=input_layer, outputs=reconstr)
      hub.attach_message("stamp_size", tf.train.Int64List(value=[hparams.problem_hparams.img_len]))
      try:
        hub.attach_message("pixel_size", tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale[res] for res in hparams.problem_hparams.resolutions]))
      except AttributeError:
        hub.attach_message("pixel_size", tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale]))
    spec = hub.create_module_spec(make_model_spec, drop_collections=['checkpoints'])
    decoder = hub.Module(spec, name="decoder_module")
    hub.register_module_for_export(decoder, "decoder")

    reconstr = decoder(code)
    return reconstr , {"bottleneck_loss": 0.0}

  encoder_layers = None
  self.is1d = hparams.sample_width == 1
  if (hparams.mode != tf.estimator.ModeKeys.PREDICT
      or self._encode_on_predict):
    labels = features["targets_raw"]
    labels_shape = common_layers.shape_list(labels)

    shape = common_layers.shape_list(labels)
    with tf.variable_scope('encoder_module'):
      x = self.embed(tf.expand_dims(labels, -1))

    if shape[2] == 1:
      self.is1d = True

    # Run encoder.
    with tf.variable_scope('encoder_module'):
      # If we have access to the PSF, we add this information to the encoder
      # Note that we only support single band images so far...
      if hparams.encode_psf and 'psf' in features:
        psf_image = tf.expand_dims(tf.signal.irfft2d(tf.cast(features['psf'][...,0], tf.complex64)), axis=-1)
        # Roll the image to undo the fftshift, assuming x1 zero padding and x2 subsampling
        psf_image = tf.roll(psf_image, shift=[input_shape[1], input_shape[2]], axis=[1,2])
        psf_image = tf.image.resize_with_crop_or_pad(psf_image, input_shape[1], input_shape[2])
        net_psf = tf.layers.conv2d(psf_image,
                                   hparams.hidden_size // 4, 5,
                                   padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)

    # Bottleneck.
    with tf.variable_scope('encoder_module'):
      b, b_loss = self.bottleneck(x)

    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    res_size = common_layers.shape_list(x)[-1]
    with tf.variable_scope('decoder_module'):
      b = self.unbottleneck(b, res_size)
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(tf.reduce_sum(
          tf.squared_difference(x_stop, b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    self._cur_bottleneck_tensor = b
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)

    with tf.variable_scope('decoder_module'):
      x = self.unbottleneck(b, res_size)

  # Run decoder.
  with tf.variable_scope('decoder_module'):
    x = self.decoder(x, encoder_layers)

  # Cut to the right size and mix before returning.
  res = x
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    res = x[:, :shape[1], :shape[2], :]

  with tf.variable_scope('decoder_module'):
    reconstr = tf.layers.dense(res, shape[-1], name="autoencoder_final",
                               activation=output_activation)

  # We apply an optional apodization to suppress the output near the borders.
  if hparams.output_apodization > 0:
    nx = reconstr.get_shape().as_list()[1]
    alpha = 2 * hparams.output_apodization / nx
    from scipy.signal.windows import tukey
    # Create a Tukey window.
    w = tukey(nx, alpha)
    w = np.outer(w, w).reshape((1, nx, nx, 1)).astype('float32')
    # And penalize non-zero values at the border.
    apo_loss = tf.reduce_mean(
        tf.reduce_sum(((1. - w) * reconstr)**2, axis=[1, 2, 3]))
  else:
    w = 1.0
    apo_loss = 0.

  # We apply the window
  reconstr = reconstr * w

  # Optionally regularizes further the output
  # Anisotropic TV:
  tv = tf.reduce_mean(tf.image.total_variation(reconstr))
  # Smoothed Isotropic TV:
  #im_dx, im_dy = tf.image.image_gradients(reconstr)
  #tv = tf.reduce_sum(tf.sqrt(im_dx**2 + im_dy**2 + 1e-6), axis=[1,2,3])
  #tv = tf.reduce_mean(tv)

  image_summary("without_psf",tf.reshape(reconstr, labels_shape))
  # Apply channel-wise convolution with the PSF if requested
  if hparams.apply_psf and 'psf' in features:
    output_list = []
    for i in range(shape[3]):
      output_list.append(tf.squeeze(convolve(
          tf.expand_dims(reconstr[..., i], -1),
          tf.cast(features['psf'][..., i], tf.complex64),
          zero_padding_factor=1)))
    reconstr = tf.stack(output_list, axis=-1)
    reconstr = tf.reshape(reconstr, shape)

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss,
      "total_variation": hparams.total_variation_loss * tv,
      "apodization_loss": hparams.apodization_loss * apo_loss,
  }

  loglik = loglikelihood_fn(labels, reconstr, features, hparams)
  targets_loss = tf.reduce_mean(- loglik)

  tf.summary.scalar("negloglik", targets_loss)
  tf.summary.scalar("bottleneck_loss", b_loss)

  # Compute final loss
  losses["training"] = targets_loss + b_loss + hparams.bottleneck_l2_factor * xb_loss + hparams.total_variation_loss * tv +  hparams.apodization_loss * apo_loss
  logits = tf.reshape(reconstr, labels_shape)

  image_summary("ae", reconstr)
  image_summary("input", labels)

  return logits, losses
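
loglikelihood_fn is not included in this snippet. Purely as a hypothetical illustration of a function with a compatible signature, an isotropic Gaussian log-likelihood summed per example could look like the sketch below (the likelihood_sigma hparam is an assumption, not from the original code):

import tensorflow as tf

def loglikelihood_fn(labels, reconstr, features, hparams):
    """Hypothetical isotropic Gaussian log-likelihood, summed per example."""
    del features  # a real implementation might read per-example noise maps here
    sigma = getattr(hparams, "likelihood_sigma", 1.0)  # assumed hparam name
    diff = tf.cast(labels, tf.float32) - reconstr
    log_prob = -0.5 * tf.square(diff / sigma)
    axes = list(range(1, diff.shape.ndims))  # sum over all non-batch dimensions
    return tf.reduce_sum(log_prob, axis=axes)
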
Example #16
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
    """graph attention.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: optional matrix of [batch, length, length] ids indicating
      edge type
    num_edge_types: an int indicating number of edge types
  Returns:
    A Tensor of shape [batch, heads, length_q, depth_v]
  """
    with tf.variable_scope(name,
                           default_name="dot_product_attention",
                           values=[q, k, v]) as scope:
        # [batch, num_heads, query_length, memory_length]
        logits = tf.matmul(q, k, transpose_b=True)
        if adjacency_matrix is not None:
            key_head_depth = common_layers.shape_list(q)[-1]
            adjacency_vectors = make_edge_vectors(adjacency_matrix,
                                                  num_edge_types,
                                                  key_head_depth,
                                                  name=name)
            # transposing q to be [batch, length_q, heads, depth_k]
            # to allow for matmul with [batch, length_q, length_q, depth_k]
            q_t = tf.transpose(q, [0, 2, 1, 3])
            adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
            logits += tf.transpose(adj_logits, [0, 2, 1, 3])
            # [batch, depth, num_nodes, num_nodes]
        if bias is not None:
            logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
        # dropping out the attention links for each of the heads
        weights = common_layers.dropout_with_broadcast_dims(
            weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
        if common_layers.should_generate_summaries() and make_image_summary:
            common_attention.attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
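
A small hypothetical usage sketch of graph_attention; with adjacency_matrix=None it reduces to ordinary dot-product attention (shapes are illustrative):

import tensorflow as tf

batch, heads, length, depth = 2, 4, 8, 16
q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
v = tf.random_normal([batch, heads, length, depth])
# No adjacency matrix, so no edge vectors are added to the logits.
out = graph_attention(q, k, v, bias=None, make_image_summary=False)
# out has shape [batch, heads, length, depth]
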
Example #17
File: mtsa.py Project: taoshen58/mtsa
def dot_product_attention_mtsa(
    q,
    k,
    v,
    bias,
    dropout_rate=0.0,
    image_shapes=None,
    name=None,
    make_image_summary=True,
    save_weights_to=None,
    dropout_broadcast_dims=None,
    use_k_mtsa=True,
    afn_extra='none',
    afn_dot='exp',
    afn_multi='exp',
    bias_start=0.,
    bi_direction=False,
):
    """Dot-product attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    use_k_mtsa: if True, compute the multi-dim (source2token) logits from k
      rather than v.
    afn_extra: activation-function name (resolved via afn_name2fn) for the
      optional extra multi-dim layer, or 'none' to skip it.
    afn_dot: activation-function name for the token2token logits.
    afn_multi: activation-function name for the multi-dim logits.
    bias_start: initial bias for the final multi-dim dense layer.
    bi_direction: if True, split the heads into a forward-masked half and a
      backward-masked half.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
    print("!!!!!dot_product_attention_mtsa!!!!!")
    with tf.variable_scope(name,
                           default_name="dot_product_attention",
                           values=[q, k, v]) as scope:
        # get dim
        dim_q = q.get_shape().as_list()[-1]
        dim_k = k.get_shape().as_list()[-1]
        dim_v = v.get_shape().as_list()[-1]
        # prepare
        multi_logits_scale_factor = 1. / math.sqrt(
            dim_v) if afn_multi.startswith('scaled') else 1.
        afn_extra, afn_dot, afn_multi = afn_name2fn(afn_extra), afn_name2fn(
            afn_dot), afn_name2fn(afn_multi)
        # if bias is not None:
        #   inp_mask_1d = tf.to_float(tf.equal(bias, 0.))  # bs,1,1,vl
        #   inp_mask_1d = tf.transpose(inp_mask_1d, [0, 1, 3, 2])   # bs,1,vl,1
        # else:
        #   inp_mask_1d = None

        # token2token self attention
        dot_logits = tf.matmul(q, k, transpose_b=True)  # bs,hd,ql,vl
        if bias is not None:
            bias = common_layers.cast_like(bias, dot_logits)  # 1/bs,1,ql/1,vl
            dot_logits += bias
        e_dot_logits = afn_dot(dot_logits)  # bs,hd,ql,vl
        if bi_direction:
            head_num = v.get_shape().as_list()[1]
            ql, vl = tf.shape(q)[-2], tf.shape(v)[-2]
            assert head_num is not None
            assert head_num % 2 == 0
            ones_mat = tf.ones([ql, vl], tf.float32)
            mul_mask_fw = tf.matrix_band_part(ones_mat, -1,
                                              0)  #  Lower triangular part.
            mul_mask_bw = tf.matrix_band_part(ones_mat, 0,
                                              -1)  #  Upper triangular part.
            mul_mask_fw_tile = tf.tile(tf.expand_dims(mul_mask_fw, 0),
                                       [head_num // 2, 1, 1])
            mul_mask_bw_tile = tf.tile(tf.expand_dims(mul_mask_bw, 0),
                                       [head_num // 2, 1, 1])
            mul_mask = tf.expand_dims(tf.concat(
                [mul_mask_fw_tile, mul_mask_bw_tile], axis=0),
                                      axis=0)
            e_dot_logits *= mul_mask

        # source2token self-attention
        multi_logits = multi_head_dense_layer(
            k if use_k_mtsa else v, dim_v, True,
            bias_start if afn_extra is None else 0., 'multi_logits1')
        if afn_extra is not None:  # use one extra layer for multi-dim
            multi_logits = multi_head_dense_layer(afn_extra(multi_logits),
                                                  dim_v, True, bias_start,
                                                  'multi_logits2')
        e_multi_logits = afn_multi(multi_logits *
                                   multi_logits_scale_factor)  # bs,hd,vl,vd
        # if inp_mask_1d is not None:  # use mask for exp_logits
        #   e_multi_logits *= inp_mask_1d

        # mtsa
        accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,hd,ql,vd
        accum_z_deno = tf.where(  # in case of NaN and Inf
            tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
            accum_z_deno, tf.ones_like(accum_z_deno))

        # attention dropout
        e_dot_logits = common_layers.dropout_with_broadcast_dims(
            e_dot_logits,
            math.sqrt(1. - dropout_rate),
            broadcast_dims=dropout_broadcast_dims)
        e_multi_logits = common_layers.dropout_with_broadcast_dims(
            e_multi_logits,
            math.sqrt(1. - dropout_rate),
            broadcast_dims=dropout_broadcast_dims)
        rep_mul_score = v * e_multi_logits  # bs,hd,vl,vd
        accum_rep_mul_score = tf.matmul(e_dot_logits,
                                        rep_mul_score)  # bs,hd,ql,vd
        # calculate the final attention results
        attn_res = accum_rep_mul_score / accum_z_deno
        # if inp_mask_1d is not None:  # use mask for output
        #   attn_res *= inp_mask_1d

        # ============ for vis =======
        weights = e_dot_logits / (tf.reduce_sum(
            e_dot_logits, axis=-1, keepdims=True, name="attention_weights") +
                                  0.00001)
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
            save_weights_to[scope.name + "/logits"] = dot_logits
        if common_layers.should_generate_summaries() and make_image_summary:
            common_attention.attention_image_summary(weights, image_shapes)
        return attn_res
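A minimal NumPy sketch (illustrative only, with hypothetical shapes) of the combination computed above: attn = (E @ (V * M)) / (E @ M), where E are the exponentiated token2token logits and M the multi-dimensional source2token scores.

import numpy as np

batch, heads, q_len, v_len, dim_v = 2, 4, 5, 5, 8
rng = np.random.RandomState(0)

e_dot = np.exp(rng.randn(batch, heads, q_len, v_len))                    # bs,hd,ql,vl
e_multi = 1.0 / (1.0 + np.exp(-rng.randn(batch, heads, v_len, dim_v)))   # bs,hd,vl,vd
v = rng.randn(batch, heads, v_len, dim_v)                                # bs,hd,vl,vd

numerator = e_dot @ (v * e_multi)                  # bs,hd,ql,vd
denominator = np.maximum(e_dot @ e_multi, 1e-9)    # guard against division by zero
attn_res = numerator / denominator                 # bs,hd,ql,vd
print(attn_res.shape)                              # (2, 4, 5, 8)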
Example #18
def dot_product_area_attention(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               attention_image_summary=None,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               max_area_width=1,
                               max_area_height=1,
                               memory_height=1,
                               area_key_mode="mean",
                               area_value_mode="sum",
                               top_k_areas=0,
                               area_temperature=1.0,
                               training=True):
    """Dot-product area attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v]. Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    attention_image_summary: the callback for making image summary of attention.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "concat", "sum", "sample_concat", and "sample_sum".
    area_value_mode: the mode for computing area values, which can be "mean",
      "max", or "sum".
    top_k_areas: if positive, keep only the top-k area weights for attention.
    area_temperature: the temperature for attention softmax.
    training: indicating if it is in the training mode.
  Returns:
    Tensor with shape [..., length_q, depth_v].
  """

    tf.logging.info(
        "dot_product_area_attention: "
        "area_h=%d, area_w=%d, mem_h=%d, "
        "area_key_mode=%s, area_value_mode=%s, "
        "area_temperature=%f", max_area_height, max_area_width, memory_height,
        area_key_mode, area_value_mode, area_temperature)
    with tf.variable_scope(name,
                           default_name="dot_product_area_attention",
                           values=[q, k, v]) as scope:
        mem_shape = common_layers.shape_list(k)
        batch_size = mem_shape[0]
        head_size = mem_shape[1]
        length = mem_shape[2]
        depth = mem_shape[3]
        k_area = compute_area_key(tf.reshape(k, [-1, length, depth]),
                                  max_area_width=max_area_width,
                                  max_area_height=max_area_height,
                                  height=memory_height,
                                  mode=area_key_mode,
                                  training=training)
        if area_value_mode == "mean":
            v_area, _, _, _, _ = compute_area_features(
                tf.reshape(v, [-1, length, depth]),
                max_area_width=max_area_width,
                max_area_height=max_area_height,
                height=memory_height)
        elif area_value_mode == "max":
            v_area, _, _ = basic_pool(tf.reshape(v, [-1, length, depth]),
                                      max_area_width=max_area_width,
                                      max_area_height=max_area_height,
                                      height=memory_height,
                                      fn=tf.reduce_max)
        elif area_value_mode == "sum":
            _, _, v_area, _, _ = compute_area_features(
                tf.reshape(v, [-1, length, depth]),
                max_area_width=max_area_width,
                max_area_height=max_area_height,
                height=memory_height)
        else:
            raise ValueError("Unsupported area value mode=%s" %
                             area_value_mode)
        k = tf.reshape(k_area, [batch_size, head_size, -1, depth])
        v = tf.reshape(v_area, [batch_size, head_size, -1, depth])
        logits = tf.matmul(q, k,
                           transpose_b=True)  # [..., length_q, length_kv]
        if bias is not None:
            bias = common_layers.cast_like(bias, logits)
            with tf.name_scope("compute_area_att_bias", values=[bias]):
                bias_shape = common_layers.shape_list(bias)
                mem_length = bias_shape[-1]
                bias_values = tf.reshape(tf.to_float(tf.less(bias, -1)),
                                         [-1, mem_length, 1])
                _, _, padding_sum, _, _ = compute_area_features(
                    bias_values,
                    max_area_width=max_area_width,
                    max_area_height=max_area_height,
                    height=memory_height)
                bias = tf.where(tf.cast(tf.to_int32(padding_sum), tf.bool),
                                tf.fill(tf.shape(padding_sum), -np.inf),
                                tf.zeros_like(padding_sum, dtype=tf.float32))
                bias = tf.reshape(
                    bias, [bias_shape[0], bias_shape[1], bias_shape[2], -1])
            logits += bias
        logits = logits / area_temperature
        weights = tf.nn.softmax(logits, name="attention_weights")
        if top_k_areas > 0:
            tf.logging.info("area_attention top_k_areas=%d", top_k_areas)
            top_k = tf.minimum(
                common_layers.shape_list(weights)[-1], top_k_areas)
            top_weights, _ = tf.nn.top_k(weights, k=top_k)
            min_values = tf.reduce_min(top_weights, -1, keepdims=True)
            weights = tf.where(tf.greater_equal(weights, min_values), weights,
                               tf.zeros_like(weights))
            weights = tf.div(weights, tf.reduce_sum(weights, -1,
                                                    keepdims=True))
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
            save_weights_to[scope.name + "/logits"] = logits
        # Drop out attention links for each head.
        weights = common_layers.dropout_with_broadcast_dims(
            weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
        if common_layers.should_generate_summaries(
        ) and attention_image_summary:
            attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
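A hedged usage sketch (TF1 graph mode, hypothetical shapes); it assumes this function and its area helpers are importable in the current module.

import tensorflow as tf

batch, heads, length, depth = 2, 4, 16, 32
q = tf.random_normal([batch, heads, length, depth])
k = tf.random_normal([batch, heads, length, depth])
v = tf.random_normal([batch, heads, length, depth])

out = dot_product_area_attention(
    q, k, v, bias=None,
    dropout_rate=0.1,
    max_area_width=3,            # pool keys/values over areas up to width 3
    area_key_mode="mean",
    area_value_mode="sum",
    training=True)
# out: [batch, heads, length, depth_v], same layout as plain dot-product attention.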
Example #19
def noisy_top_k_gating(x,
                       num_experts,
                       train,
                       k=2,
                       initializer=tf.zeros_initializer(),
                       noisy_gating=True,
                       noise_epsilon=1e-2,
                       name=None):
    """Noisy top-k gating.
  See paper: https://arxiv.org/abs/1701.06538.
  Args:
    x: input Tensor with shape [batch_size, input_size]
    num_experts: an integer
    train: a boolean - we only add noise at training time.
    k: an integer - number of experts per example
    initializer: an initializer
    noisy_gating: a boolean
    noise_epsilon: a float
    name: an optional string
  Returns:
    gates: a Tensor with shape [batch_size, num_experts]
    load: a Tensor with shape [num_experts]
  """
    with tf.variable_scope(name, default_name="noisy_top_k_gating"):
        input_size = x.get_shape().as_list()[-1]
        w_gate = tf.get_variable("w_gate", [input_size, num_experts],
                                 tf.float32, initializer)
        if noisy_gating:
            w_noise = tf.get_variable("w_noise", [input_size, num_experts],
                                      tf.float32, initializer)
        clean_logits = tf.matmul(x, w_gate)
        if noisy_gating:
            raw_noise_stddev = tf.matmul(x, w_noise)
            noise_stddev = (
                (tf.nn.softplus(raw_noise_stddev) + noise_epsilon) *
                (tf.to_float(train)))
            noisy_logits = clean_logits + (
                tf.random_normal(tf.shape(clean_logits)) * noise_stddev)
            logits = noisy_logits
            if common_layers.should_generate_summaries():
                tf.summary.histogram("noisy_logits", noisy_logits)
                tf.summary.histogram("noise_stddev", noise_stddev)
        else:
            logits = clean_logits
        top_logits, top_indices = _my_top_k(logits, min(k + 1, num_experts))
        # top k logits has shape [batch, k]
        top_k_logits = tf.slice(top_logits, [0, 0], [-1, k])
        top_k_indices = tf.slice(top_indices, [0, 0], [-1, k])
        top_k_gates = tf.nn.softmax(top_k_logits)
        # This will be a `Tensor` of shape `[batch_size, num_experts]`, with zeros in the
        # positions corresponding to all but the top k experts per example.
        gates = _rowwise_unsorted_segment_sum(top_k_gates, top_k_indices,
                                              num_experts)
        if noisy_gating and k < num_experts:
            load = tf.reduce_sum(
                _prob_in_top_k(clean_logits, noisy_logits, noise_stddev,
                               top_logits, k), 0)
        else:
            load = _gates_to_load(gates)
        if common_layers.should_generate_summaries():
            tf.summary.histogram("importance", tf.reduce_sum(gates, 0))
            tf.summary.histogram("load", load)
        return gates, load
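A hedged usage sketch; the helpers referenced above (_my_top_k, _prob_in_top_k, _rowwise_unsorted_segment_sum, _gates_to_load) are assumed to be available in the same module.

import tensorflow as tf

x = tf.random_normal([32, 512])              # [batch_size, input_size]
gates, load = noisy_top_k_gating(
    x, num_experts=8, train=True, k=2, noisy_gating=True)
# gates: [32, 8], each row has at most k=2 non-zero entries that sum to 1.
# load:  [8], a smooth estimate of how many examples reach each expert; it is
#        typically turned into a load-balancing auxiliary loss.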
Example #20
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
        or self._encode_on_predict):
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.expand_dims(labels, axis=-1)
      x = self.embed(x)
      target_codes = x
      print(x)
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      res_size = common_layers.shape_list(x)[-1]
      b = self.unbottleneck(b, res_size)
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(
            tf.squared_difference(x_stop, b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      reconstr = tf.layers.dense(res, self.num_channels, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    reconstr = tf.layers.dense(res, self.num_channels, name="autoencoder_final")
    reconstr = tf.reshape(reconstr, labels_shape)

    targets_loss = self.reconstruction_loss(reconstr, labels)
    losses["training"] = targets_loss

    self.image_summary("inputs", labels)
    self.image_summary("ae", reconstr)
    return reconstr, losses
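The clipping of xb_loss above caps the bottleneck-distance loss at an annealed maximum without changing its gradient direction. A minimal stand-alone sketch: the inverse_exp_decay below is a rough stand-in for common_layers.inverse_exp_decay, which in the real code depends on the global step.

def inverse_exp_decay(step, steps, min_value=0.001):
    # Rough stand-in: exponential ramp from min_value at step 0 up to 1.0 at `steps`.
    return min_value ** (1.0 - min(step, steps) / float(steps))

for step in [0, 1000, 4000]:
    clip_max = 1.0 / inverse_exp_decay(step, steps=4000)  # ~1000 early, ~1.0 late
    xb_loss = 3.7                                         # hypothetical raw distance
    xb_clip = max(xb_loss, clip_max)
    print(step, round(xb_loss * clip_max / xb_clip, 3))
# Early on clip_max is huge and the loss passes through unchanged (3.7); once
# clip_max has annealed to ~1.0 the effective loss is capped at clip_max.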
Example #21
    def body(self, features):
        hparams = self.hparams
        is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            labels = features["targets_raw"]
            vocab_size = self._problem_hparams.target_modality.top_dimensionality
            shape = common_layers.shape_list(labels)
            x = tf.one_hot(labels, vocab_size)
            x = self.embed(x)
            target_codes = x
            is1d = shape[2] == 1
            self.is1d = is1d
            # Run encoder.
            x, encoder_layers = self.encoder(x)
            # Bottleneck.
            b, b_loss = self.bottleneck(x)
            xb_loss = 0.0
            b_shape = common_layers.shape_list(b)
            self._cur_bottleneck_tensor = b
            b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
            if not is_training:
                x = b
            else:
                l = 2**hparams.num_hidden_layers
                warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
                nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
                if common_layers.should_generate_summaries():
                    tf.summary.scalar("nomix_p_bottleneck", nomix_p)
                rand = tf.random_uniform(common_layers.shape_list(x))
                # This is the distance between b and x. Having this as loss helps learn
                # the bottleneck function, but if we back-propagated to x it would be
                # minimized by just setting x=0 and b=0 -- so we don't want too much
                # of the influence of this, and we stop-gradient to not zero-out x.
                x_stop = tf.stop_gradient(x)
                xb_loss = tf.reduce_mean(
                    tf.reduce_sum(tf.square(x_stop - b), axis=-1))
                # To prevent this loss from exploding we clip at 1, but anneal clipping.
                clip_max = 1.0 / common_layers.inverse_exp_decay(
                    warm_step, min_value=0.001)
                xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
                xb_loss *= clip_max / xb_clip
                x = tf.where(tf.less(rand, nomix_p), b, x)
            if hparams.gan_loss_factor != 0.0:
                # Add a purely sampled batch on which we'll compute the GAN loss.
                g = self.unbottleneck(self.sample(shape=b_shape),
                                      common_layers.shape_list(x)[-1],
                                      reuse=True)
                x = tf.concat([g, x], axis=0)
                encoder_layers = [
                    tf.concat([l, l], axis=0) for l in encoder_layers
                ]
        else:
            if self._cur_bottleneck_tensor is None:
                b = self.sample()
            else:
                b = self._cur_bottleneck_tensor
            res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
            res_size = min(res_size, hparams.max_hidden_size)
            x = self.unbottleneck(b, res_size)
        # Run decoder.
        x = self.decoder(x, encoder_layers)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            return x, {"bottleneck_loss": 0.0}
        # Cut to the right size and mix before returning.
        res = x[:, :shape[1], :shape[2], :]

        # Final dense layer.
        res = tf.layers.dense(res,
                              self.num_channels * hparams.hidden_size,
                              name="res_dense")

        output_shape = common_layers.shape_list(res)[:-1] + [
            self.num_channels, self.hparams.hidden_size
        ]
        res = tf.reshape(res, output_shape)

        if hparams.gan_loss_factor != 0.0:
            res_gan, res = tf.split(res, 2, axis=0)

        # Losses.
        losses = {
            "bottleneck_extra": b_loss,
            "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
        }

        if hparams.use_vq_loss:
            vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.2,
                min_value=hparams.vq_temperature * 2)
            if hparams.mode != tf.estimator.ModeKeys.TRAIN:
                vq_temperature = None
            with tf.variable_scope("vq_loss"):
                (reconstr, _, target_codes, code_loss,
                 targets_loss) = discretization.vq_loss(
                     res, labels, vocab_size, temperature=vq_temperature)
            losses["code_loss"] = code_loss * hparams.code_loss_factor
            losses["training"] = targets_loss
        else:
            reconstr = tf.layers.dense(res,
                                       vocab_size,
                                       name="autoencoder_final")
            targets_loss = tf.losses.sparse_softmax_cross_entropy(
                logits=reconstr, labels=labels)
            losses["training"] = targets_loss

        # GAN losses.
        if hparams.gan_loss_factor != 0.0:
            update_means_factor = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps, min_value=0.0001)
            if hparams.use_vq_loss:
                with tf.variable_scope("vq_loss", reuse=True):
                    update_means = tf.less(tf.random_uniform([]),
                                           update_means_factor)
                    reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
                        res_gan,
                        labels,
                        vocab_size,
                        do_update=update_means,
                        temperature=vq_temperature)
                    code_loss_gan *= hparams.code_loss_factor * update_means_factor
                    losses["code_loss_gan"] = code_loss_gan
            else:
                reconstr_gan = tf.layers.dense(res_gan,
                                               vocab_size,
                                               name="autoencoder_final",
                                               reuse=True)
                reconstr_gan = tf.nn.log_softmax(reconstr_gan)
                if is_training and hparams.gumbel_temperature > 0.0:
                    gumbel_samples = discretization.gumbel_sample(
                        common_layers.shape_list(reconstr_gan))
                    gumbel_samples *= hparams.gumbel_noise_factor
                    reconstr_gan += gumbel_samples
                    reconstr_sample = latent_layers.multinomial_sample(
                        reconstr_gan, temperature=hparams.gumbel_temperature)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 hparams.gumbel_temperature)
                else:
                    reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
                    reconstr_gan = tf.nn.softmax(reconstr_gan /
                                                 0.1)  # Sharpen a bit.
                # Use 1-hot forward, softmax backward.
                reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
                reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
                # Embed to codes.
                gan_codes = self.embed(reconstr_gan)

        # Add GAN loss if requested.
        gan_loss = 0.0
        if hparams.gan_loss_factor != 0.0:
            self.image_summary("gan", reconstr_gan)

            def discriminate(x):
                return self.discriminator(x, is_training=is_training)

            tc_shape = common_layers.shape_list(target_codes)
            if len(tc_shape) > 4:
                target_codes = tf.reshape(
                    target_codes,
                    tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
                gan_codes = tf.reshape(
                    gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
            gan_lr = common_layers.inverse_exp_decay(
                hparams.gan_codes_warmup_steps * 1.5)
            rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
            gan_loss = common_layers.sliced_gan_loss(
                target_codes, rev_grad_gan_codes, discriminate,
                self.hparams.num_sliced_vecs)
            gan_loss *= hparams.gan_loss_factor * update_means_factor
            losses["gan_loss"] = -gan_loss

        self.image_summary("ae", reconstr)
        logits = reconstr
        return logits, losses
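A minimal sketch of the "1-hot forward, softmax backward" trick used for the sampled GAN path above (TF1-style; shapes are hypothetical).

import tensorflow as tf

logits = tf.random_normal([2, 5, 10])        # hypothetical [batch, length, vocab]
soft = tf.nn.softmax(logits / 0.1)           # sharpened softmax: backward path
hard = tf.one_hot(tf.argmax(soft, axis=-1), depth=10)   # one-hot: forward path
# Forward value equals `hard`; gradients flow only through `soft`, because the
# stop_gradient term removes `soft` from the forward value but not from the
# backward pass.
straight_through = soft + hard - tf.stop_gradient(soft)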
Example #22
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
  """graph attention.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: optional matrix of [batch, length, length] ids indicating
      edge type
    num_edge_types: an int indicating number of edge types
  Returns:
    A Tensor with shape [batch, heads, length_q, depth_v]
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(q, k, transpose_b=True)
    if adjacency_matrix is not None:
      key_head_depth = common_layers.shape_list(q)[-1]
      adjacency_vectors = make_edge_vectors(
          adjacency_matrix,
          num_edge_types,
          key_head_depth,
          name=name)
      # transposing q to be [batch, length_q, heads, depth_k]
      # to allow for matmul with [batch, length_q, length_q, depth_k]
      q_t = tf.transpose(q, [0, 2, 1, 3])
      adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
      logits += tf.transpose(adj_logits, [0, 2, 1, 3])
      # [batch, depth, num_nodes, num_nodes]
    if bias is not None:
      logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
    # dropping out the attention links for each of the heads
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
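A hedged usage sketch for attention over a small graph with typed edges; it assumes make_edge_vectors and the tensor2tensor helpers used above are importable, and shapes follow the docstring.

import tensorflow as tf

batch, heads, num_nodes, depth = 2, 4, 6, 16
q = tf.random_normal([batch, heads, num_nodes, depth])
k = tf.random_normal([batch, heads, num_nodes, depth])
v = tf.random_normal([batch, heads, num_nodes, depth])
# Integer edge-type ids in [0, num_edge_types); 0 could be reserved for "no edge".
adjacency = tf.random_uniform([batch, num_nodes, num_nodes],
                              maxval=5, dtype=tf.int32)

out = graph_attention(q, k, v, bias=None,
                      dropout_rate=0.1,
                      adjacency_matrix=adjacency,
                      num_edge_types=5)
# out: [batch, heads, num_nodes, depth], with edge-type embeddings added to the logits.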
Example #23
def autoencoder_body(self, features):
    """ Customized body function for autoencoders acting on continuous images.
  This is based on tensor2tensor.models.research.AutoencoderBasic.body
  and should be compatible with most derived classes.

  The original autoencoder class relies on embedding the channels to a discrete
  vocabulary and defines the loss on that vocab. It's cool and all, but here we
  prefer expressing the reconstruction loss as an actual continuous likelihood
  function.
  """
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    output_activation = tf.nn.softplus if hparams.output_activation == 'softplus' else None
    input_shape = [
        None,
    ] + common_layers.shape_list(features["inputs"])[1:]

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        # In predict mode, we also define TensorFlow Hub modules for all pieces of
        # the autoencoder
        # First build encoder spec
        def make_model_spec():
            input_layer = tf.placeholder(tf.float32, shape=input_shape)
            x = self.embed(tf.expand_dims(input_layer, -1))
            x, encoder_layers = self.encoder(x)
            b, b_loss = self.bottleneck(x)
            hub.add_signature(inputs=input_layer, outputs=b)

        def make_model_spec_psf():
            input_layer = tf.placeholder(tf.float32, shape=input_shape)
            psf_layer = tf.placeholder(tf.float32, shape=input_shape)
            x = self.embed(tf.expand_dims(input_layer, -1))

            # If we have access to the PSF, we add this information to the encoder
            if hparams.encode_psf and 'psf' in features:
                net_psf = tf.layers.conv2d(psf_layer,
                                           hparams.hidden_size // 4,
                                           5,
                                           padding='same',
                                           name="psf_embed_1")
                net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
                x, encoder_layers = self.encoder(
                    tf.concat([x, net_psf], axis=-1))
            else:
                x, encoder_layers = self.encoder(x)
            b, b_loss = self.bottleneck(x)
            hub.add_signature(inputs={
                'input': input_layer,
                'psf': psf_layer
            },
                              outputs=b)

        spec = hub.create_module_spec(
            make_model_spec_psf if hparams.encode_psf else make_model_spec,
            drop_collections=['checkpoints'])
        encoder = hub.Module(spec, name="encoder_module")
        hub.register_module_for_export(encoder, "encoder")

        if hparams.encode_psf:
            code = encoder({
                'input': features["inputs"],
                'psf': features['psf']
            })
        else:
            code = encoder(features["inputs"])
        b_shape = [
            None,
        ] + common_layers.shape_list(code)[1:]
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)

        # Second build decoder spec
        def make_model_spec():
            input_layer = tf.placeholder(tf.float32, shape=b_shape)
            x = self.unbottleneck(input_layer, res_size)
            x = self.decoder(x, None)
            reconstr = tf.layers.dense(x,
                                       self.num_channels,
                                       name="autoencoder_final",
                                       activation=output_activation)
            hub.add_signature(inputs=input_layer, outputs=reconstr)
            hub.attach_message(
                "stamp_size",
                tf.train.Int64List(value=[hparams.problem_hparams.img_len]))
            hub.attach_message(
                "pixel_size",
                tf.train.FloatList(
                    value=[hparams.problem_hparams.pixel_scale]))

        spec = hub.create_module_spec(make_model_spec,
                                      drop_collections=['checkpoints'])
        decoder = hub.Module(spec, name="decoder_module")
        hub.register_module_for_export(decoder, "decoder")

        reconstr = decoder(code)
        return reconstr, {"bottleneck_loss": 0.0}

    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if (hparams.mode != tf.estimator.ModeKeys.PREDICT
            or self._encode_on_predict):
        labels = features["targets_raw"]
        labels_shape = common_layers.shape_list(labels)

        shape = common_layers.shape_list(labels)
        with tf.variable_scope('encoder_module'):
            x = self.embed(tf.expand_dims(labels, -1))

        if shape[2] == 1:
            self.is1d = True

        # Run encoder.
        with tf.variable_scope('encoder_module'):
            # If we have access to the PSF, we add this information to the encoder
            if hparams.encode_psf and 'psf' in features:
                net_psf = tf.layers.conv2d(features['psf'],
                                           hparams.hidden_size // 4,
                                           5,
                                           padding='same',
                                           name="psf_embed_1")
                net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
                x, encoder_layers = self.encoder(
                    tf.concat([x, net_psf], axis=-1))
            else:
                x, encoder_layers = self.encoder(x)

        # Bottleneck.
        with tf.variable_scope('encoder_module'):
            b, b_loss = self.bottleneck(x)

        xb_loss = 0.0
        b_shape = common_layers.shape_list(b)
        self._cur_bottleneck_tensor = b
        res_size = common_layers.shape_list(x)[-1]
        with tf.variable_scope('decoder_module'):
            b = self.unbottleneck(b, res_size)
        if not is_training:
            x = b
        else:
            l = 2**hparams.num_hidden_layers
            warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
            nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
            if common_layers.should_generate_summaries():
                tf.summary.scalar("nomix_p_bottleneck", nomix_p)
            rand = tf.random_uniform(common_layers.shape_list(x))
            # This is the distance between b and x. Having this as loss helps learn
            # the bottleneck function, but if we back-propagated to x it would be
            # minimized by just setting x=0 and b=0 -- so we don't want too much
            # of the influence of this, and we stop-gradient to not zero-out x.
            x_stop = tf.stop_gradient(x)
            xb_loss = tf.reduce_mean(
                tf.reduce_sum(tf.squared_difference(x_stop, b), axis=-1))
            # To prevent this loss from exploding we clip at 1, but anneal clipping.
            clip_max = 1.0 / common_layers.inverse_exp_decay(warm_step,
                                                             min_value=0.001)
            xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
            xb_loss *= clip_max / xb_clip
            x = tf.where(tf.less(rand, nomix_p), b, x)
    else:
        if self._cur_bottleneck_tensor is None:
            b = self.sample()
        else:
            b = self._cur_bottleneck_tensor
        self._cur_bottleneck_tensor = b
        res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
        res_size = min(res_size, hparams.max_hidden_size)

        with tf.variable_scope('decoder_module'):
            x = self.unbottleneck(b, res_size)

    # Run decoder.
    with tf.variable_scope('decoder_module'):
        x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        res = x[:, :shape[1], :shape[2], :]

    with tf.variable_scope('decoder_module'):
        reconstr = tf.layers.dense(res,
                                   self.num_channels,
                                   name="autoencoder_final",
                                   activation=output_activation)

    # Apply channel-wise convolution with the PSF if requested
    # TODO: Handle multiple bands
    if hparams.apply_psf and 'psf' in features:
        if self.num_channels > 1:
            raise NotImplementedError
        rec_padded = tf.pad(
            reconstr[:, :, :, 0],
            [[0, 0], [0, int(hparams.psf_convolution_pad_factor * shape[1])],
             [0, int(hparams.psf_convolution_pad_factor * shape[2])]])
        psf_padded = tf.pad(
            features['psf'][..., 0],
            [[0, 0], [0, int(hparams.psf_convolution_pad_factor * shape[1])],
             [0, int(hparams.psf_convolution_pad_factor * shape[2])]])
        reconstr = tf.expand_dims(tf.spectral.irfft2d(
            tf.spectral.rfft2d(rec_padded) *
            tf.cast(tf.abs(tf.spectral.rfft2d(psf_padded)), tf.complex64)),
                                  axis=-1)
        reconstr = reconstr[:, :shape[1], :shape[2], :]

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    loglik = loglikelihood_fn(labels, reconstr, features, hparams)
    targets_loss = tf.reduce_mean(-loglik)

    tf.summary.scalar("negloglik", targets_loss)
    tf.summary.scalar("bottleneck_loss", b_loss)

    losses["training"] = targets_loss
    logits = tf.reshape(reconstr, labels_shape)

    image_summary("ae", reconstr)
    image_summary("input", labels)

    return logits, losses
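A minimal sketch of the FFT-based PSF convolution applied above: zero-pad, multiply the image spectrum by the modulus of the PSF spectrum, and crop back. TF1 ops; the shapes and pad amount are hypothetical.

import tensorflow as tf

size, pad = 32, 16
image = tf.random_normal([1, size, size])
psf = tf.random_normal([1, size, size])

img_p = tf.pad(image, [[0, 0], [0, pad], [0, pad]])
psf_p = tf.pad(psf, [[0, 0], [0, pad], [0, pad]])

convolved = tf.spectral.irfft2d(
    tf.spectral.rfft2d(img_p) *
    tf.cast(tf.abs(tf.spectral.rfft2d(psf_p)), tf.complex64))
convolved = convolved[:, :size, :size]   # crop back to the original stamp size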