Example No. 1
  def body(self, features):
    """Body of the model.

    Args:
      features: a dictionary with the tensors.

    Returns:
      A pair (predictions, losses) where predictions is the generated image
      and losses is a dictionary of losses (that get added for the final loss).
    """
    features["targets"] = features["inputs"]
    is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Input images.
    inputs = tf.to_float(features["targets_raw"])

    # Noise vector.
    z = tf.random_uniform([self.hparams.batch_size,
                           self.hparams.bottleneck_bits],
                          minval=-1, maxval=1, name="z")

    # Generator output: fake images.
    out_shape = common_layers.shape_list(inputs)[1:4]
    g = self.generator(z, is_training, out_shape)

    losses = self.losses(inputs, g)  # pylint: disable=not-callable

    summary_g_image = tf.reshape(
        g[0, :], [1] + common_layers.shape_list(inputs)[1:])
    tf.summary.image("generated", summary_g_image, max_outputs=1)

    if is_training:  # Returns a dummy output and the losses dictionary.
      return tf.zeros_like(inputs), losses
    return tf.reshape(g, tf.shape(inputs)), losses
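Every snippet on this page builds shapes with common_layers.shape_list from tensor2tensor. As a rough stand-in for what that helper does (not the library's exact implementation), the sketch below returns static dimensions where they are known and dynamic tf.shape() entries otherwise, so the result can be fed straight to tf.reshape; shape_list_sketch is a hypothetical name.

import tensorflow as tf


def shape_list_sketch(x):
  """Hypothetical stand-in for common_layers.shape_list (a sketch).

  Uses static dimensions where known and falls back to dynamic tf.shape()
  entries otherwise, so the result mixes Python ints and scalar tensors.
  """
  x = tf.convert_to_tensor(x)
  static = x.get_shape().as_list()
  dynamic = tf.shape(x)
  return [static[i] if static[i] is not None else dynamic[i]
          for i in range(len(static))]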
Example No. 2
  def infer(self, features, *args, **kwargs):
    """Produce predictions from the model by running it."""
    del args, kwargs
    if "targets" not in features:
      if "infer_targets" in features:
        targets_shape = common_layers.shape_list(features["infer_targets"])
      elif "inputs" in features:
        targets_shape = common_layers.shape_list(features["inputs"])
        targets_shape[1] = self.hparams.video_num_target_frames
      else:
        raise ValueError("no inputs are given.")
      features["targets"] = tf.zeros(targets_shape, dtype=tf.float32)

    output, _ = self(features)  # pylint: disable=not-callable

    if not isinstance(output, dict):
      output = {"targets": output}

    x = output["targets"]
    if self.is_per_pixel_softmax:
      x_shape = common_layers.shape_list(x)
      x = tf.reshape(x, [-1, x_shape[-1]])
      x = tf.argmax(x, axis=-1)
      x = tf.reshape(x, x_shape[:-1])
    else:
      x = tf.squeeze(x, axis=-1)
      x = tf.to_int64(tf.round(x))
    output["targets"] = x
    if self.hparams.reward_prediction:
      output["target_reward"] = tf.argmax(output["target_reward"], axis=-1)

    # only required for decoding.
    output["outputs"] = output["targets"]
    output["scores"] = output["targets"]
    return output
Example No. 3
  def infer(self,
            features=None,
            decode_length=50,
            beam_size=1,
            top_beams=1,
            alpha=0.0,
            use_tpu=False):
    """Produce predictions from the model."""
    if not features:
      features = {}
    inputs_old = None
    if "inputs" in features and len(features["inputs"].shape) < 4:
      inputs_old = features["inputs"]
      features["inputs"] = tf.expand_dims(features["inputs"], 2)

    # Create an initial targets tensor.
    if "partial_targets" in features:
      initial_output = tf.convert_to_tensor(features["partial_targets"])
    else:
      batch_size = common_layers.shape_list(features["inputs"])[0]
      length = common_layers.shape_list(features["inputs"])[1]
      target_length = tf.to_int32(2.0 * tf.to_float(length))
      initial_output = tf.zeros((batch_size, target_length, 1, 1),
                                dtype=tf.int64)

    features["targets"] = initial_output
    logits, _ = self(features)  # pylint: disable=not-callable
    samples = tf.argmax(logits, axis=-1)
    if inputs_old is not None:  # Restore to not confuse Estimator.
      features["inputs"] = inputs_old
    return samples
Example No. 4
  def bottleneck(self, x):
    hparams = self.hparams
    b, _ = super(AutoencoderDualDiscrete, self).bottleneck(x)
    if hparams.mode == tf.estimator.ModeKeys.EVAL:
      return b, 0.0
    bt, bi = tf.split(b, 2, axis=0)
    if self.hparams.mode != tf.estimator.ModeKeys.TRAIN:
      return tf.concat([bi, bi], axis=0), 0.0
    # Share the first hparams.bottleneck_shared_bits.
    shared = (bt + bi) / 2  # -1 if both -1, 1 if both were 1, 0 if disagree.
    rand = tf.random_uniform(common_layers.shape_list(bt))
    br = tf.where(rand < 0.5, bt, bi)  # Break ties at random.
    bs = tf.where(tf.equal(shared, 0.0), br, shared)  # Elementwise comparison.
    bs = tf.concat([bs, bs], axis=0)
    n = hparams.bottleneck_shared_bits
    step = tf.train.get_global_step()
    zero = tf.constant(0, dtype=tf.int64)
    if step is None:
      step = zero
    step = tf.maximum(zero, step - hparams.bottleneck_shared_bits_start_warmup)
    f = common_layers.inverse_lin_decay(
        hparams.bottleneck_shared_bits_stop_warmup, min_value=0.1, step=step)
    n = tf.where(step > 1, n * f, n)
    n = tf.cast(n, tf.int64)
    b_shape = common_layers.shape_list(b)
    b = tf.concat([bs[..., :n], b[..., n:]], axis=-1)
    b = tf.reshape(b, b_shape)
    return b, 0.0
Example No. 5
def padded_sequence_accuracy(predictions,
                             labels,
                             weights_fn=common_layers.weights_nonzero):
  """Percentage of times that predictions matches labels everywhere (non-0)."""
  # If the last dimension is 1 then we're using L1/L2 loss.
  if common_layers.shape_list(predictions)[-1] == 1:
    return rounding_sequence_accuracy(
        predictions, labels, weights_fn=weights_fn)
  with tf.variable_scope(
      "padded_sequence_accuracy", values=[predictions, labels]):
    padded_predictions, padded_labels = common_layers.pad_with_zeros(
        predictions, labels)
    weights = weights_fn(padded_labels)

    # Flatten, keeping batch dim (and num_classes dim for predictions)
    # TPU argmax can only deal with a limited number of dimensions
    predictions_shape = common_layers.shape_list(padded_predictions)
    batch_size = predictions_shape[0]
    num_classes = predictions_shape[-1]
    flat_size = common_layers.list_product(
        common_layers.shape_list(padded_labels)[1:])
    padded_predictions = tf.reshape(
        padded_predictions,
        [batch_size, common_layers.list_product(predictions_shape[1:-1]),
         num_classes])
    padded_labels = tf.reshape(padded_labels, [batch_size, flat_size])
    weights = tf.reshape(weights, [batch_size, flat_size])

    outputs = tf.to_int32(tf.argmax(padded_predictions, axis=-1))
    padded_labels = tf.to_int32(padded_labels)
    not_correct = tf.to_float(tf.not_equal(outputs, padded_labels)) * weights
    axis = list(range(1, len(outputs.get_shape())))
    correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
    return correct_seq, tf.constant(1.0)
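To make the metric concrete, here is a small NumPy check of the same "every non-padding position must match" logic, with made-up toy values that are not part of the library:

import numpy as np

predictions = np.array([[3, 5, 0], [3, 4, 0]])   # 0 marks padding
labels = np.array([[3, 5, 0], [3, 7, 0]])
weights = (labels != 0).astype(np.float32)       # weights_nonzero analogue
not_correct = (predictions != labels).astype(np.float32) * weights
correct_seq = 1.0 - np.minimum(1.0, not_correct.sum(axis=1))
print(correct_seq)  # [1. 0.]: first sequence fully correct, second is not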
Example No. 6
def dae(x, hparams, name):
  with tf.variable_scope(name):
    m = tf.layers.dense(x, hparams.v_size, name="mask")
    if hparams.softmax_k > 0:
      m, kl = top_k_softmax(m, hparams.softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)
    # Gumbel-softmax sample.
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = hparams.kl_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)
    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(tf.less(tf.random_uniform([]), 0.9),
                          lambda: temperature,
                          lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    kl = - tf.reduce_max(logsm, axis=-1)
    if _DO_SUMMARIES:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))
    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, hparams.v_size))
    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, hparams.v_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
    d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
    d_dev = - tf.reduce_mean(d_variance)
    ret = s
    if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
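dae() relies on a gumbel_sample helper that is not shown in this snippet. A minimal sketch of the standard Gumbel(0, 1) sampling trick it presumably wraps (an assumption, not necessarily the library's exact code):

import tensorflow as tf


def gumbel_sample_sketch(shape, eps=1e-20):
  """Sketch of a Gumbel(0, 1) sampler: -log(-log(U)) with U ~ Uniform(0, 1).

  The eps terms only guard against log(0); this illustrates the standard
  trick and is not necessarily identical to the helper used above.
  """
  uniform = tf.random_uniform(shape, minval=0.0, maxval=1.0)
  return -tf.log(-tf.log(uniform + eps) + eps)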
Example No. 7
  def embed(self, x):
    """Embedding function that takes discrete latent and returns embedding.

    Args:
        x: Input to the discretization bottleneck.
    Returns:
        Continuous embedding to be passed on to the decoder.

    Raises:
        ValueError: For unknown or missing arguments.
    """
    shape_x = common_layers.shape_list(x)
    x_flat = tf.reshape(x, [-1, 1])
    c = self.int_to_bit(x_flat, num_bits=self.hparams.z_size, base=2)
    shape = common_layers.shape_list(c)
    new_shape = shape
    new_shape.append(self.hparams.num_blocks)
    new_shape.append(int(self.hparams.z_size / self.hparams.num_blocks))
    c = tf.to_int32(tf.reshape(c, shape=new_shape))
    h1_shape = shape_x
    h1_shape.append(self.hparams.hidden_size)
    h1 = tf.zeros(dtype=tf.float32, shape=h1_shape)
    c_int = self.bit_to_int(
        c, num_bits=int(self.hparams.z_size / self.hparams.num_blocks), base=2)
    c_hot = tf.one_hot(c_int, depth=self.hparams.block_v_size, axis=-1)
    c_hot_flat = tf.reshape(
        c_hot, shape=[-1, self.hparams.num_blocks, self.hparams.block_v_size])
    h1 = tf.matmul(tf.transpose(c_hot_flat, perm=[1, 0, 2]), self.means)
    h1 = tf.transpose(h1, perm=[1, 0, 2])
    h1 = tf.reshape(h1, shape=h1_shape)
    h1_shape[0] = self.hparams.batch_size
    h2 = tf.layers.dense(tf.nn.relu(h1), self.hparams.filter_size, name="vch2")
    res = tf.layers.dense(
        tf.nn.relu(h2), self.hparams.hidden_size, name="vcfin")
    return res
Example No. 8
    def symbols_to_logits_fn(ids):
      """Go from ids to logits."""
      ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
      if "partial_targets" in features:
        pt = features["partial_targets"]
        pt_length = common_layers.shape_list(pt)[1]
        pt = tf.tile(pt, [1, beam_size])
        pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
        ids = tf.concat([pt, ids], axis=1)

      features["targets"] = ids
      self._coverage = None
      logits, _ = self(features)  # pylint: disable=not-callable
      # now self._coverage is a coverage tensor for the first datashard.
      # it has shape [batch_size] and contains floats between 0 and
      # source_length.
      if self._problem_hparams:
        modality = self._problem_hparams.target_modality
        if modality.top_is_pointwise:
          return tf.squeeze(logits, axis=[1, 2, 3])
      # -1 due to the pad above.
      current_output_position = common_layers.shape_list(ids)[1] - 1
      logits = logits[:, current_output_position, :, :]
      return tf.squeeze(logits, axis=[1, 2])
Example No. 9
def postprocess_image(x, rows, cols, hparams):
  """Postprocessing after decoding."""
  batch = common_layers.shape_list(x)[0]
  channels = 256
  x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size])
  # targets = common_layers.conv(x, 256, (1, 1), name="output_conv")
  targets = tf.layers.dense(x, 256, use_bias=True, activation=None,
                            name="output_conv")
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    y = targets
    y = tf.reshape(y, [batch, -1, hparams.img_len*3, channels])
    yshape = common_layers.shape_list(y)
    block_length = hparams.query_shape[0]
    block_width = hparams.query_shape[1]

    # Break into block row wise.
    y = tf.reshape(y,
                   [batch, yshape[1] // block_length,
                    block_length,
                    yshape[2], channels])
    yshape = common_layers.shape_list(y)
    # Break into blocks width wise.
    y_blocks = tf.reshape(y,
                          [batch, yshape[1], yshape[2],
                           yshape[3] // block_width,
                           block_width, channels])

    # Reshape targets as [batch_size, num_blocks_rows, num_block_cols,
    # block_length, block_width, channels]
    targets = tf.transpose(y_blocks, [0, 1, 3, 2, 4, 5])

  return targets
Example No. 10
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
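Conceptually, the causal bias returned by attention_bias_lower_triangle lets each position attend only to itself and earlier positions by adding a large negative value to the masked logits. A toy NumPy picture of that pattern (an illustration only; the library call also adds broadcasting dimensions):

import numpy as np

length = 4
neg_inf = -1e9
# Zeros where attention is allowed, a large negative bias above the diagonal.
causal_bias = np.triu(np.full((length, length), neg_inf), k=1)
print(causal_bias)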
Example No. 11
def create_output(decoder_output, rows, cols, targets, hparams):
  """Creates output from decoder output and vars.

  Args:
    decoder_output: Tensor of shape [batch, ...], where ... can be any rank such
      that the number of elements is batch * rows * cols * hparams.hidden_size.
    rows: Integer representing number of rows in a 2-D data point.
    cols: Integer representing number of columns in a 2-D data point.
    targets: Tensor of shape [batch, hparams.img_len, hparams.img_len,
      hparams.num_channels].
    hparams: tf.contrib.training.HParams set.

  Returns:
    Tensor of shape [batch, hparams.img_len, hparams.img_len,
    hparams.num_mixtures * 10] if hparams.likelihood is DMOL, otherwise
    [batch, hparams.img_len, hparams.img_len, hparams.num_channels, 256].
    In the special case of predict mode, it is a Tensor of rank 5.
  """
  decoded_image = postprocess_image(decoder_output, rows, cols, hparams)
  depth = common_layers.shape_list(decoded_image)[-1]
  batch, height, width, channels = common_layers.shape_list(targets)
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    y = tf.reshape(decoded_image, [batch, -1, 1, 1, depth])
    output = y[:, :height, :, :, :]
  elif likelihood == DistributionType.CAT:
    # Unpack the cols dimension of the Categorical.
    output = tf.reshape(decoded_image,
                        [batch, height, width, channels, depth])
  else:
    output = decoded_image
  return output
Example No. 12
def vq_discrete_unbottleneck(x, hidden_size):
  """Simple undiscretization from vector quantized representation."""
  x_shape = common_layers.shape_list(x)
  x = tf.to_float(x)
  bottleneck_size = common_layers.shape_list(x)[-1]
  means, _, _ = get_vq_bottleneck(bottleneck_size, hidden_size)
  result = tf.matmul(tf.reshape(x, [-1, x_shape[-1]]), means)
  return tf.reshape(result, x_shape[:-1] + [hidden_size])
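The matmul with means is just a codebook lookup of one-hot rows. A toy check with a made-up 4-entry codebook of hidden size 2:

import tensorflow as tf

means = tf.constant([[0., 0.], [1., 1.], [2., 2.], [3., 3.]])  # hypothetical codebook
codes = tf.one_hot([1, 3], depth=4)      # one-hot bottleneck, shape [2, 4]
embedded = tf.matmul(codes, means)       # selects rows 1 and 3 of the codebook
with tf.Session() as sess:
  print(sess.run(embedded))              # [[1. 1.] [3. 3.]]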
Example No. 13
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # during training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1]
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len*channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len*channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len*channels
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of img_len
      # times the number of channels. This is needed for positional encodings
      # and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1]*x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
Example No. 14
  def logits_to_samples(logits):
    """Get samples from logits."""
    # If the last dimension is 1 then we're using L1/L2 loss.
    if common_layers.shape_list(logits)[-1] == 1:
      return tf.to_int32(tf.squeeze(logits, axis=-1))
    # Argmax in TF doesn't handle more than 5 dimensions yet.
    logits_shape = common_layers.shape_list(logits)
    argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1)
    return tf.reshape(argmax, logits_shape[:-1])
Example No. 15
  def loss(self, top_out, targets):
    predictions = top_out
    if (len(common_layers.shape_list(top_out)) != len(
        common_layers.shape_list(targets))):
      predictions = tf.squeeze(top_out, axis=[-1])
    with tf.name_scope("l2"):
      weights = self.targets_weights_fn(targets)
      l2 = tf.pow(predictions - targets, 2)
      return tf.reduce_sum(l2 * weights), tf.reduce_sum(weights)
Example No. 16
  def loss(self, top_out, targets):
    predictions = top_out
    if (len(common_layers.shape_list(top_out)) != len(
        common_layers.shape_list(targets))):
      predictions = tf.squeeze(top_out, axis=[-1])
    with tf.name_scope("log_poisson"):
      weights = self.targets_weights_fn(targets)
      lp_loss = tf.nn.log_poisson_loss(targets, predictions)
      return tf.reduce_sum(lp_loss * weights), tf.reduce_sum(weights)
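For reference, with its default compute_full_loss=False, tf.nn.log_poisson_loss reduces to exp(log_input) - targets * log_input (no Stirling term). A quick NumPy sanity check with made-up numbers:

import numpy as np

targets, log_input = 3.0, 1.2
print(np.exp(log_input) - targets * log_input)  # approximately -0.28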
Example No. 17
  def construct_model(self, images, actions, rewards):
    images = tf.unstack(images, axis=0)
    actions = tf.unstack(actions, axis=0)
    rewards = tf.unstack(rewards, axis=0)

    batch_size = common_layers.shape_list(images[0])[0]
    context_frames = self.hparams.video_num_input_frames

    # Predicted images and rewards.
    gen_rewards, gen_images, latent_means, latent_stds = [], [], [], []

    # LSTM states.
    lstm_state = [None] * 7

    # Create scheduled sampling function
    ss_func = self.get_scheduled_sample_func(batch_size)

    pred_image = tf.zeros_like(images[0])
    pred_reward = tf.zeros_like(rewards[0])
    latent = None
    for timestep, image, action, reward in zip(
        range(len(images)-1), images[:-1], actions[:-1], rewards[:-1]):
      # Scheduled Sampling
      done_warm_start = timestep > context_frames - 1
      groundtruth_items = [image, reward]
      generated_items = [pred_image, pred_reward]
      input_image, input_reward = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)

      # Latent
      # TODO(mbz): should we use input_image instead of image?
      latent_images = tf.stack([image, images[timestep+1]], axis=0)
      latent_mean, latent_std = self.construct_latent_tower(
          latent_images, time_axis=0)
      latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
      latent_means.append(latent_mean)
      latent_stds.append(latent_std)

      # Prediction
      pred_image, lstm_state, _ = self.construct_predictive_tower(
          input_image, input_reward, action, lstm_state, latent)

      if self.hparams.reward_prediction:
        pred_reward = self.reward_prediction(
            pred_image, input_reward, action, latent)
        pred_reward = common_video.decode_to_shape(
            pred_reward, common_layers.shape_list(input_reward), "reward_dec")
      else:
        pred_reward = input_reward

      gen_images.append(pred_image)
      gen_rewards.append(pred_reward)

    gen_images = tf.stack(gen_images, axis=0)
    gen_rewards = tf.stack(gen_rewards, axis=0)

    return gen_images, gen_rewards, latent_means, latent_stds
Example No. 18
def postprocess_image(x, rows, cols, hparams):
  """Postprocessing after decoding.

  Args:
    x: Tensor of shape [batch, ...], where ... can be any rank such that the
      number of elements in x is batch * rows * cols * hparams.hidden_size.
    rows: Integer representing number of rows in a 2-D data point.
    cols: Integer representing number of columns in a 2-D data point.
    hparams: tf.contrib.training.HParams set.

  Returns:
    Tensor of shape [batch, rows, cols, depth], where depth is
    hparams.num_mixtures * 10 if hparams.likelihood is DMOL, otherwise 256. In
    the special case of inference and block raster scan order, it is a Tensor
    of shape [batch, num_blocks_rows, num_block_cols, block_length, block_width,
    depth].
  """
  batch = common_layers.shape_list(x)[0]
  x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size])
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  if likelihood == DistributionType.DMOL:
    depth = hparams.num_mixtures * 10
    targets = tf.layers.dense(x,
                              depth,
                              use_bias=False,
                              activation=None,
                              name="output_conv")
  else:
    depth = 256
    targets = tf.layers.dense(x,
                              depth,
                              use_bias=True,
                              activation=None,
                              name="output_conv")
  if (hparams.mode == tf.contrib.learn.ModeKeys.INFER and
      hparams.block_raster_scan):
    y = targets
    yshape = common_layers.shape_list(y)
    block_length = hparams.query_shape[0]
    block_width = hparams.query_shape[1]

    # Break into block row wise.
    y = tf.reshape(y,
                   [batch, yshape[1] // block_length, block_length,
                    yshape[2], depth])
    yshape = common_layers.shape_list(y)
    # Break into blocks width wise.
    y_blocks = tf.reshape(y,
                          [batch, yshape[1], yshape[2],
                           yshape[3] // block_width, block_width, depth])

    # Reshape targets as [batch, num_blocks_rows, num_block_cols, block_length,
    # block_width, depth].
    targets = tf.transpose(y_blocks, [0, 1, 3, 2, 4, 5])

  return targets
Example No. 19
def top_k_experts(x, k, hparams):
  x_shape = common_layers.shape_list(x)
  x_flat = tf.reshape(x, [-1, common_layers.shape_list(x)[-1]])
  is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
  gates, load = expert_utils.noisy_top_k_gating(
      x_flat, 2 ** hparams.z_size, is_training, k)
  gates_shape = [x_shape[0], x_shape[1], x_shape[2], 2 ** hparams.z_size]
  gates = tf.reshape(gates, gates_shape)
  load_loss = expert_utils.cv_squared(load)
  return gates, load_loss
Example No. 20
  def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
            alpha=0.0, use_tpu=False):
    """Produce predictions from the model."""
    if not self._hparams.do_mask:
      infer_out = super(TransformerAE, self).infer(
          features, decode_length, beam_size, top_beams, alpha, use_tpu=use_tpu)
      return infer_out["outputs"]
    if not features:
      features = {}
    inputs_old = None
    if "inputs" in features and len(features["inputs"].shape) < 4:
      inputs_old = features["inputs"]
      features["inputs"] = tf.expand_dims(features["inputs"], 2)

    # Create an initial targets tensor.
    if "partial_targets" in features:
      initial_output = tf.convert_to_tensor(features["partial_targets"])
    else:
      # inputs might not be present in features (e.g.: language modeling),
      # in which case we fallback to 'infer_targets' for calculating initial
      # input shape, type, etc.
      inputs_or_targets = features.get("inputs", features.get("infer_targets"))
      batch_size = common_layers.shape_list(inputs_or_targets)[0]
      length = common_layers.shape_list(inputs_or_targets)[1]
      hidden_dim = common_layers.shape_list(inputs_or_targets)[-1]
      target_length = tf.to_int32(2.0 * tf.to_float(length))
      initial_output = tf.zeros((batch_size, target_length, 1, hidden_dim),
                                dtype=inputs_or_targets.dtype)

    features["targets"] = initial_output
    logits, _ = self(features)  # pylint: disable=not-callable
    # If the target modality is real-valued (float32), the logits already are
    # the samples, so only class logits get an argmax below.
    if inputs_or_targets.dtype == tf.float32:
      samples = logits
    else:
      samples = tf.argmax(logits, axis=-1)

    # More steps.
    self.predict_mask = 0.0  # Use the provided targets this time.
    how_many_more_steps = 0  # Set to 1 or more for Gibbs-like sampling.
    for _ in range(how_many_more_steps):
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        features["targets"] = samples
        logits, _ = self(features)  # pylint: disable=not-callable
        if inputs_or_targets.dtype == tf.float32:
          # When target_modality is real, the last axis does not represent
          # classes, so it should not be argmax'ed
          samples = logits
        else:
          samples = tf.argmax(logits, axis=-1)

    self.predict_mask = 1.0
    if inputs_old is not None:  # Restore to not confuse Estimator.
      features["inputs"] = inputs_old
    return samples
Example No. 21
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None):
  """Original Transformer decoder."""
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      decoder_output = transformer.transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_bias,
          encoder_decoder_attention_bias,
          hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # We have to reshape targets to [b, 32, 32, 3 * hidden_size] because
      # otherwise prepare_image will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels*hparams.hidden_size])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1, 1,
                                                 hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output
Example No. 22
  def loss(self, top_out, targets):
    """Compute loss numerator and denominator for one shard of output."""
    logits = top_out
    logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:])
    targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
    cutoff = getattr(self._model_hparams, "video_modality_loss_cutoff", 0.01)
    return common_layers.padded_cross_entropy(
        logits,
        targets,
        self._model_hparams.label_smoothing,
        cutoff=cutoff,
        weights_fn=self.targets_weights_fn)
Example No. 23
  def reduce_dimensions(predictions, labels):
    """Reduce dimensions for high-dimensional predictions and labels."""
    # We will treat the first dimensions as batch. One example is video frames.
    if len(predictions.get_shape()) > 5:
      predictions_shape = common_layers.shape_list(predictions)
      predictions = tf.reshape(
          predictions, [predictions_shape[0], predictions_shape[1], -1,
                        predictions_shape[-1]])
      labels_shape = common_layers.shape_list(labels)
      labels = tf.reshape(
          labels, [labels_shape[0], labels_shape[1], -1])
    return predictions, labels
Example No. 24
  def logits_to_samples(logits, key):
    """Get samples from logits."""
    # If the last dimension is 1 then we're using L1/L2 loss.
    if common_layers.shape_list(logits)[-1] == 1:
      return tf.to_int32(tf.squeeze(logits, axis=-1))
    if key == "targets":
      return pixels_from_softmax(
          logits, gumbel_noise_factor=0.0,
          temperature=hparams.pixel_sampling_temperature)
    # Argmax in TF doesn't handle more than 5 dimensions yet.
    logits_shape = common_layers.shape_list(logits)
    argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1)
    return tf.reshape(argmax, logits_shape[:-1])
Example No. 25
  def loss(self, top_out, targets):
    """Compute loss numerator and denominator for one shard of output."""
    logits = top_out
    logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:-1])
    targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
    weights = self.targets_weights_fn(targets)
    # Shift targets by 0.5 so later just casting to int gives the prediction.
    # So for int targets, say 0 and 7, we actually train to predict 0.5 and 7.5.
    # Later (in metrics or infer) this is cast to int anyway. Also, we have no
    # loss beyond self.cutoff = 0.2 as these are already correct predictions.
    targets = tf.to_float(targets) + 0.5
    loss = self.internal_loss(logits, targets)
    return tf.reduce_sum(loss * weights), tf.reduce_sum(weights)
Example No. 26
  def discriminator(self, x, is_training):
    """Discriminator architecture based on InfoGAN.

    Args:
      x: input images, shape [bs, h, w, channels]
      is_training: boolean, are we in train or eval mode.

    Returns:
      out_logit: the output logits (before sigmoid).
    """
    hparams = self.hparams
    with tf.variable_scope(
        "discriminator", initializer=tf.random_normal_initializer(stddev=0.02)):
      batch_size, height, width = common_layers.shape_list(x)[:3]
      # Mapping x from [bs, h, w, c] to [bs, 1]
      net = tf.layers.conv2d(
          x, 64, (4, 4), strides=(2, 2), padding="SAME", name="d_conv1")
      # [bs, h/2, w/2, 64]
      net = lrelu(net)
      net = tf.layers.conv2d(
          net, 128, (4, 4), strides=(2, 2), padding="SAME", name="d_conv2")
      # [bs, h/4, w/4, 128]
      if hparams.discriminator_batchnorm:
        net = tf.layers.batch_normalization(
            net, training=is_training, momentum=0.999, name="d_bn2")
      net = lrelu(net)
      size = height * width
      net = tf.reshape(net, [batch_size, size * 8])  # [bs, h * w * 8]
      net = tf.layers.dense(net, 1024, name="d_fc3")  # [bs, 1024]
      if hparams.discriminator_batchnorm:
        net = tf.layers.batch_normalization(
            net, training=is_training, momentum=0.999, name="d_bn3")
      net = lrelu(net)
      return net
Example No. 27
    def process_single_frame(prev_outputs, inputs):
      """Process a single frame of the video."""
      cur_image, input_reward, action = inputs
      time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs

      # Sample from softmax (by argmax). This is a no-op for non-softmax loss.
      prev_image = self.get_sampled_frame(prev_image)

      generated_items = [prev_image]
      groundtruth_items = [cur_image]
      done_warm_start = tf.greater(time_step, context_frames - 1)
      input_image, = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)

      # Prediction
      pred_image, lstm_states, _ = self.construct_predictive_tower(
          input_image, None, action, lstm_states, latent)

      if self.hparams.reward_prediction:
        reward_input_image = self.get_sampled_frame(pred_image)
        if self.hparams.reward_prediction_stop_gradient:
          reward_input_image = tf.stop_gradient(reward_input_image)
        with tf.control_dependencies([time_step]):
          frame_buf = [reward_input_image] + frame_buf[:-1]
        pred_reward = self.reward_prediction(frame_buf, None, action, latent)
        pred_reward = common_video.decode_to_shape(
            pred_reward, common_layers.shape_list(input_reward), "reward_dec")
      else:
        pred_reward = prev_reward

      time_step += 1
      outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states)

      return outputs
Example No. 28
  def bottom_compress(self, inputs, name="bottom"):
    """Transform input from data space to model space.

    Converts RGB pixel values to real numbers and combines the values for each
    pixel to form a representation of image_length x image_length dims.

    Args:
      inputs: A Tensor with shape [batch, ...]
      name: string, scope.
    Returns:
      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
    """
    with tf.variable_scope(name):
      inputs = common_layers.convert_rgb_to_real(inputs)
      ishape = common_layers.shape_list(inputs)
      inputs = tf.reshape(inputs, [-1, ishape[1], ishape[2] * ishape[3], 1])
      inputs.set_shape([None, None, None, 1])
      # We compress RGB intensities for each pixel using a conv.
      x = common_layers.conv_block(
          inputs,
          self._body_input_depth, [((1, 1), (1, 3))],
          first_relu=False,
          padding="VALID",
          strides=(1, 3),
          force2d=True,
          name="conv_input")
      return x
Example No. 29
def lstm_cell(inputs,
              state,
              num_units,
              use_peepholes=False,
              cell_clip=0.0,
              initializer=None,
              num_proj=None,
              num_unit_shards=None,
              num_proj_shards=None,
              reuse=None,
              name=None):
  """Full LSTM cell."""
  input_shape = common_layers.shape_list(inputs)
  cell = tf.contrib.rnn.LSTMCell(num_units,
                                 use_peepholes=use_peepholes,
                                 cell_clip=cell_clip,
                                 initializer=initializer,
                                 num_proj=num_proj,
                                 num_unit_shards=num_unit_shards,
                                 num_proj_shards=num_proj_shards,
                                 reuse=reuse,
                                 name=name,
                                 state_is_tuple=False)
  if state is None:
    state = cell.zero_state(input_shape[0], tf.float32)
  outputs, new_state = cell(inputs, state)
  return outputs, new_state
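A minimal usage sketch of the wrapper above (hypothetical shapes and names): passing state=None creates a zero state from the batch size, and the returned flat state would be fed back in on the next timestep.

import tensorflow as tf

inputs = tf.zeros([8, 16])  # [batch, depth], made-up sizes
outputs, state = lstm_cell(inputs, state=None, num_units=32, name="lstm_demo")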
Example No. 30
def scheduled_sample_count(ground_truth_x,
                           generated_x,
                           batch_size,
                           scheduled_sample_var):
  """Sample batch with specified mix of groundtruth and generated data points.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size
    scheduled_sample_var: number of ground-truth examples to include in batch.
  Returns:
    New batch with num_ground_truth sampled from ground_truth_x and the rest
    from generated_x.
  """
  num_ground_truth = scheduled_sample_var
  idx = tf.random_shuffle(tf.range(batch_size))
  ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
  generated_idx = tf.gather(idx, tf.range(num_ground_truth, batch_size))

  ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
  generated_examps = tf.gather(generated_x, generated_idx)

  output = tf.dynamic_stitch([ground_truth_idx, generated_idx],
                             [ground_truth_examps, generated_examps])
  # If the batch size is statically known, set it.
  if isinstance(batch_size, int):
    output.set_shape([batch_size] + common_layers.shape_list(output)[1:])
  return output
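A toy run of the function above with made-up values: scheduled_sample_var=2 and batch_size=4 puts two rows from the ground-truth batch and two from the generated batch back into their original positions.

import tensorflow as tf

ground_truth_x = tf.constant([[1.], [2.], [3.], [4.]])
generated_x = tf.constant([[10.], [20.], [30.], [40.]])
mixed = scheduled_sample_count(ground_truth_x, generated_x,
                               batch_size=4, scheduled_sample_var=2)
with tf.Session() as sess:
  print(sess.run(mixed))  # e.g. [[1.], [20.], [3.], [40.]] -- two of each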
Example No. 31
    def body_single(self, features):
        hparams = self.hparams
        filters = hparams.hidden_size
        kernel1, kernel2 = (3, 3), (4, 4)

        # Embed the inputs.
        inputs_shape = common_layers.shape_list(features["inputs"])
        # Using non-zero bias initializer below for edge cases of uniform inputs.
        x = tf.layers.dense(
            features["inputs"],
            filters,
            name="inputs_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Down-stride.
        layer_inputs = [x]
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("downstride%d" % i):
                layer_inputs.append(x)
                x = common_layers.make_even_size(x)
                if i < hparams.filter_double_steps:
                    filters *= 2
                x = common_attention.add_timing_signal_nd(x)
                x = tf.layers.conv2d(x,
                                     filters,
                                     kernel2,
                                     activation=common_layers.belu,
                                     strides=(2, 2),
                                     padding="SAME")
                x = common_layers.layer_norm(x)

        # Add embedded action if present.
        if "input_action" in features:
            action = features["input_action"][:, -1, :]
            x = self.inject_additional_input(x, action, "action_enc",
                                             hparams.action_injection)

        x, extra_loss = self.inject_latent(x, features, filters)

        # Run a stack of convolutions.
        for i in range(hparams.num_hidden_layers):
            with tf.variable_scope("layer%d" % i):
                y = tf.nn.dropout(x, 1.0 - hparams.dropout)
                y = tf.layers.conv2d(y,
                                     filters,
                                     kernel1,
                                     activation=common_layers.belu,
                                     strides=(1, 1),
                                     padding="SAME")
                if i == 0:
                    x = y
                else:
                    x = common_layers.layer_norm(x + y)

        # Up-convolve.
        layer_inputs = list(reversed(layer_inputs))
        for i in range(hparams.num_compress_steps):
            with tf.variable_scope("upstride%d" % i):
                if "input_action" in features:
                    x = self.inject_additional_input(x, action, "action_enc",
                                                     hparams.action_injection)
                if i >= hparams.num_compress_steps - hparams.filter_double_steps:
                    filters //= 2
                x = tf.layers.conv2d_transpose(x,
                                               filters,
                                               kernel2,
                                               activation=common_layers.belu,
                                               strides=(2, 2),
                                               padding="SAME")
                y = layer_inputs[i]
                shape = common_layers.shape_list(y)
                x = x[:, :shape[1], :shape[2], :]
                x = common_layers.layer_norm(x + y)
                x = common_attention.add_timing_signal_nd(x)

        # Cut down to original size.
        x = x[:, :inputs_shape[1], :inputs_shape[2], :]
        if self.is_per_pixel_softmax:
            x = tf.layers.dense(x,
                                hparams.problem.num_channels * 256,
                                name="logits")
        else:
            x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

        # Reward prediction if needed.
        if "target_reward" not in features:
            return x
        reward_pred = tf.expand_dims(  # Add a fake channels dim.
            tf.reduce_mean(x, axis=[1, 2], keepdims=True),
            axis=3)
        return {"targets": x, "target_reward": reward_pred}, extra_loss
Example No. 32
  def body(self, features):
    hparams = self.hparams
    input_shape = common_layers.shape_list(features['inputs'])
    batch_size, _, frame_width, frame_height, frame_channels = input_shape  # pylint: disable=unused-variable

    # Swap time and batch axes.
    input_frames = common_video.swap_time_and_batch_axes(
        tf.to_float(features['inputs']))
    target_frames = common_video.swap_time_and_batch_axes(features['targets'])

    # Get actions if they exist, otherwise use zeros.
    input_actions = self.get_input_if_exists(
        features, 'input_action', batch_size, hparams.video_num_input_frames)
    target_actions = self.get_input_if_exists(
        features, 'target_action', batch_size, hparams.video_num_target_frames)

    # Get rewards if they exist, otherwise use zeros.
    # TODO(blazej) enable rewards.
    # input_rewards = self.get_input_if_exists(
    #     features, 'input_reward', batch_size, hparams.video_num_input_frames)
    # target_rewards = self.get_input_if_exists(
    #     features, 'target_reward', batch_size,hparams.video_num_target_frames)
    # all_rewards = tf.concat([input_rewards, target_rewards], axis=0)

    all_actions = tf.concat([input_actions, target_actions], axis=0)
    all_frames = tf.concat([input_frames, target_frames], axis=0)

    all_frames = tf.unstack(all_frames, axis=0)
    all_actions = tf.unstack(all_actions, axis=0)
    all_actions = [tf.squeeze(a, 1) for a in all_actions]

    # TODO(blazej) - most likely this downsize is too strong.
    all_frames = [
        tf.image.resize_images(
            image, (IMG_HEIGHT, IMG_WIDTH),
            method=tf.image.ResizeMethod.BICUBIC)
        for image in all_frames
    ]

    enc_out_all, pred_out_all, _, van_on_enc_all = construct_model(
        all_frames,
        all_actions,
        context_frames=hparams.context_frames,
        hparams=hparams,
        is_training=self.is_training)

    enc_pred_loss, _ = calc_loss_psnr(
        enc_out_all[1:],
        pred_out_all,
        'enc_pred_loss',
        hparams=hparams,
        use_l1_loss=hparams.enc_pred_use_l1_loss)

    van_on_enc_loss, _ = calc_loss_psnr(
        van_on_enc_all,
        all_frames[1:],
        'van_on_enc_loss',
        hparams=hparams)

    enc_pred_loss_scale_delay = max(hparams.enc_pred_loss_scale_delay, 1)
    enc_pred_loss_scale = tf.nn.sigmoid(
        (tf.to_float(tf.train.get_or_create_global_step()
                    ) - enc_pred_loss_scale_delay) /
        (enc_pred_loss_scale_delay * .1)) * hparams.enc_pred_loss_scale
    tf.summary.scalar('enc_pred_loss_scale', enc_pred_loss_scale)
    epva_loss = enc_pred_loss * enc_pred_loss_scale + van_on_enc_loss
    tf.summary.scalar('epva_loss', epva_loss)

    predictions = tf.stack(van_on_enc_all)

    # TODO(mbz): clean this up!
    def fix_video_dims_and_concat_on_x_axis(x):
      x = tf.transpose(x, [1, 3, 4, 0, 2])
      x = tf.reshape(x, [batch_size, frame_height, frame_channels, -1])
      x = tf.transpose(x, [0, 3, 1, 2])
      return x

    frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
    frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
    side_by_side_video = tf.concat([frames_gd, frames_pd], axis=1)
    tf.summary.image('full_video', side_by_side_video)

    predictions = common_video.swap_time_and_batch_axes(predictions)
    predictions = tf.slice(predictions,
                           [0, hparams.video_num_input_frames-1, 0, 0, 0],
                           [-1]*5)

    return predictions, {'extra': epva_loss}
Example No. 33
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None,
                       causal=True):
  """Original Transformer decoder."""
  orig_hparams = hparams
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      if not causal:
        decoder_self_bias *= 0.

      decoder_output = transformer.transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_bias,
          encoder_decoder_attention_bias,
          hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # We have to reshape targets to [b, 32, 32, 3 * hidden_size] because
      # otherwise prepare_image will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels*hparams.hidden_size])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)

      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1, 1,
                                                 hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    hparams = orig_hparams
    return decoder_output
Example No. 34
  def body(self, features):
    hparams = self.hparams
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    if "training" in losses:
      plain_training_loss = losses.pop("training")
      losses["plain"] = plain_training_loss
    res_shape = common_layers.shape_list(basic_result)
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    targets = tf.one_hot(features["targets_raw"], vocab_size)
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      targets = tf.zeros_like(basic_result)
    targets = self.embed(targets)
    if hparams.autoregressive_gumbel_sample:
      basic_hot = self.gumbel_sample(basic_result)
    else:
      basic_hot = basic_result
    basic_result = self.embed(basic_hot)
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
    targets = tf.reshape(targets, common_layers.shape_list(basic_result))
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Sometimes it's useful to look at non-autoregressive evals.
    targets_dropout = targets
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv3")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          5,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv5")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses

    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
Example No. 35
def level_cond_prior(prior_dist, z, latent, hparams, state):
  """Returns a conditional prior for each level.

  Args:
    prior_dist: Distribution conditioned on the previous levels.
    z: Tensor, output of the previous levels.
    latent: Tensor or a list of tensors to condition the latent_distribution.
    hparams: next_frame_glow hparams.
    state: Current LSTM state. Used only if hparams.latent_dist_encoder is
           an LSTM.
  Returns:
    Tuple (loc, scale, state): mean and scale of the conditional prior, plus
    the (possibly updated) LSTM state.
  Raises:
    ValueError: If hparams.latent_dist_encoder is "pointwise" and if the shape
                of latent is different from z.
  """
  latent_dist_encoder = hparams.get("latent_dist_encoder", None)
  latent_skip = hparams.get("latent_skip", False)
  if latent_dist_encoder == "pointwise":
    last_latent = latent
    merge_std = hparams.level_scale
    latent_shape = common_layers.shape_list(latent)
    z_shape = common_layers.shape_list(z)
    if latent_shape != z_shape:
      raise ValueError("Expected latent_shape to be %s, got %s" %
                       (latent_shape, z_shape))
    latent_dist = scale_gaussian_prior(
        "latent_prior", latent, logscale_factor=3.0)
    cond_dist = merge_level_and_latent_dist(prior_dist, latent_dist,
                                            merge_std=merge_std)

  elif latent_dist_encoder == "conv_net":
    output_channels = common_layers.shape_list(z)[-1]
    last_latent = latent[-1]
    latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
    latent_stack = noise_op(latent_stack, hparams)
    cond_dist = latent_to_dist(
        "latent_stack", latent_stack, hparams=hparams,
        output_channels=output_channels)

  elif latent_dist_encoder == "conv3d_net":
    last_latent = latent[-1]
    output_channels = common_layers.shape_list(last_latent)[-1]
    num_steps = len(latent)

    # Stack across time.
    cond_latents = tf.stack(latent, axis=1)

    # Concat latents from previous levels across channels.
    prev_latents = tf.tile(tf.expand_dims(prior_dist.loc, axis=1),
                           [1, num_steps, 1, 1, 1])
    cond_latents = tf.concat((cond_latents, prev_latents), axis=-1)
    cond_latents = noise_op(cond_latents, hparams)
    cond_dist = temporal_latent_to_dist(
        "latent_stack", cond_latents, hparams, output_channels=output_channels)

  elif latent_dist_encoder == "conv_lstm":
    last_latent = latent
    output_channels = common_layers.shape_list(z)[-1]
    latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
    latent_stack = noise_op(latent_stack, hparams)
    _, state = common_video.conv_lstm_2d(
        latent_stack, state, hparams.latent_encoder_width, kernel_size=3,
        name="conv_lstm")

    cond_dist = single_conv_dist(
        "state_to_dist", state.h, output_channels=output_channels)
  if latent_skip:
    new_mean = cond_dist.loc + last_latent
    cond_dist = tfp.distributions.Normal(new_mean, cond_dist.scale)
  return cond_dist.loc, cond_dist.scale, state
Example No. 36
  def construct_predictive_tower(
      self, input_image, input_reward, action, lstm_state, latent,
      concat_latent=False):
    # Main tower
    lstm_func = common_video.conv_lstm_2d
    frame_shape = common_layers.shape_list(input_image)
    batch_size, img_height, img_width, color_channels = frame_shape
    # the number of different pixel motion predictions
    # and the number of masks for each of those predictions
    num_masks = self.hparams.num_masks
    upsample_method = self.hparams.upsample_method
    tile_and_concat = common_video.tile_and_concat

    lstm_size = self.tinyify([32, 32, 64, 64, 128, 64, 32])
    conv_size = self.tinyify([32])

    with tf.variable_scope("main", reuse=tf.AUTO_REUSE):
      hidden5, skips, layer_id = self.bottom_part_tower(
          input_image, input_reward, action, latent,
          lstm_state, lstm_size, conv_size, concat_latent=concat_latent)
      enc0, enc1 = skips

      with tf.variable_scope("upsample1", reuse=tf.AUTO_REUSE):
        enc4 = common_layers.cyclegan_upsample(
            hidden5, num_outputs=hidden5.shape.as_list()[-1],
            stride=[2, 2], method=upsample_method)

      enc1_shape = common_layers.shape_list(enc1)
      enc4 = enc4[:, :enc1_shape[1], :enc1_shape[2], :]  # Cut to shape.
      enc4 = tile_and_concat(enc4, latent, concat_latent=concat_latent)

      hidden6, lstm_state[layer_id] = lstm_func(
          enc4, lstm_state[layer_id], lstm_size[5], name="state6",
          spatial_dims=enc1_shape[1:-1])  # 16x16
      hidden6 = tile_and_concat(hidden6, latent, concat_latent=concat_latent)
      hidden6 = tfcl.layer_norm(hidden6, scope="layer_norm7")
      # Skip connection.
      hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16
      layer_id += 1

      with tf.variable_scope("upsample2", reuse=tf.AUTO_REUSE):
        enc5 = common_layers.cyclegan_upsample(
            hidden6, num_outputs=hidden6.shape.as_list()[-1],
            stride=[2, 2], method=upsample_method)

      enc0_shape = common_layers.shape_list(enc0)
      enc5 = enc5[:, :enc0_shape[1], :enc0_shape[2], :]  # Cut to shape.
      enc5 = tile_and_concat(enc5, latent, concat_latent=concat_latent)

      hidden7, lstm_state[layer_id] = lstm_func(
          enc5, lstm_state[layer_id], lstm_size[6], name="state7",
          spatial_dims=enc0_shape[1:-1])  # 32x32
      hidden7 = tfcl.layer_norm(hidden7, scope="layer_norm8")
      layer_id += 1

      # Skip connection.
      hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32

      with tf.variable_scope("upsample3", reuse=tf.AUTO_REUSE):
        enc6 = common_layers.cyclegan_upsample(
            hidden7, num_outputs=hidden7.shape.as_list()[-1],
            stride=[2, 2], method=upsample_method)
      enc6 = tfcl.layer_norm(enc6, scope="layer_norm9")
      enc6 = tile_and_concat(enc6, latent, concat_latent=concat_latent)

      if self.hparams.model_options == "DNA":
        # Using largest hidden state for predicting untied conv kernels.
        enc7 = tfl.conv2d_transpose(
            enc6,
            self.hparams.dna_kernel_size**2,
            [1, 1],
            strides=(1, 1),
            padding="SAME",
            name="convt4",
            activation=None)
      else:
        # Using largest hidden state for predicting a new image layer.
        enc7 = tfl.conv2d_transpose(
            enc6,
            color_channels,
            [1, 1],
            strides=(1, 1),
            padding="SAME",
            name="convt4",
            activation=None)
        # This allows the network to also generate one image from scratch,
        # which is useful when regions of the image become unoccluded.
        transformed = [tf.nn.sigmoid(enc7)]

      if self.hparams.model_options == "CDNA":
        # cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
        cdna_input = tfl.flatten(hidden5)
        transformed += common_video.cdna_transformation(
            input_image, cdna_input, num_masks, int(color_channels),
            self.hparams.dna_kernel_size, self.hparams.relu_shift)
      elif self.hparams.model_options == "DNA":
        # Only one mask is supported (more should be unnecessary).
        if num_masks != 1:
          raise ValueError("Only one mask is supported for DNA model.")
        transformed = [
            common_video.dna_transformation(
                input_image, enc7,
                self.hparams.dna_kernel_size, self.hparams.relu_shift)]

      masks = tfl.conv2d(
          enc6, filters=num_masks + 1, kernel_size=[1, 1],
          strides=(1, 1), name="convt7", padding="SAME")
      masks = masks[:, :img_height, :img_width, ...]
      masks = tf.reshape(
          tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
          [batch_size,
           int(img_height),
           int(img_width), num_masks + 1])
      mask_list = tf.split(
          axis=3, num_or_size_splits=num_masks + 1, value=masks)
      output = mask_list[0] * input_image
      for layer, mask in zip(transformed, mask_list[1:]):
        # TODO(mbz): take another look at this logic and verify.
        output = output[:, :img_height, :img_width, :]
        layer = layer[:, :img_height, :img_width, :]
        output += layer * mask

      # Map to softmax logits.
      if self.is_per_pixel_softmax:
        output = tf.layers.dense(
            output, self.hparams.problem.num_channels * 256, name="logits")

      mid_outputs = [enc0, enc1, enc4, enc5, enc6]
      return output, lstm_state, mid_outputs
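
A minimal NumPy sketch (shapes and names are illustrative, not the t2t API) of the mask-based compositing at the end of the tower above: a per-pixel softmax over num_masks + 1 channels blends the input frame with each transformed prediction.

import numpy as np

def composite_with_masks(input_image, transformed, mask_logits):
  """input_image: [B, H, W, C]; transformed: list of [B, H, W, C];
  mask_logits: [B, H, W, len(transformed) + 1]."""
  # Per-pixel softmax so the masks sum to one across the last axis.
  e = np.exp(mask_logits - mask_logits.max(axis=-1, keepdims=True))
  masks = e / e.sum(axis=-1, keepdims=True)
  # The first mask keeps pixels from the input frame (static background).
  output = masks[..., 0:1] * input_image
  for k, layer in enumerate(transformed, start=1):
    output += masks[..., k:k + 1] * layer
  return output

# Example: one transformed prediction plus the background mask.
img = np.random.rand(2, 8, 8, 3)
pred = np.random.rand(2, 8, 8, 3)
logits = np.random.randn(2, 8, 8, 2)
out = composite_with_masks(img, [pred], logits)
assert out.shape == img.shape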
Ejemplo n.º 37
0
def conv2d_fixed_padding(inputs,
                         filters,
                         kernel_size,
                         strides,
                         data_format="channels_first",
                         use_td=False,
                         targeting_rate=None,
                         keep_prob=None,
                         is_training=None):
    """Strided 2-D convolution with explicit padding.

  The padding is consistent and is based only on `kernel_size`, not on the
  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).

  Args:
    inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
    filters: `int` number of filters in the convolution.
    kernel_size: `int` size of the kernel to be used in the convolution.
    strides: `int` strides of the convolution.
    data_format: `str` either "channels_first" for `[batch, channels, height,
        width]` or "channels_last" for `[batch, height, width, channels]`.
    use_td: `str` one of "weight" or "unit". Set to False or "" to disable
      targeted dropout.
    targeting_rate: `float` proportion of weights to target with targeted
      dropout.
    keep_prob: `float` keep probability for targeted dropout.
    is_training: `bool` for whether the model is in training.

  Returns:
    A `Tensor` of shape `[batch, filters, height_out, width_out]`.

  Raises:
    Exception: if use_td is not valid.
  """
    if strides > 1:
        inputs = fixed_padding(inputs, kernel_size, data_format=data_format)

    if use_td:
        inputs_shape = common_layers.shape_list(inputs)
        if use_td == "weight":
            if data_format == "channels_last":
                size = kernel_size * kernel_size * inputs_shape[-1]
            else:
                size = kernel_size * kernel_size * inputs_shape[1]
            targeting_count = targeting_rate * tf.to_float(size)
            targeting_fn = common_layers.weight_targeting
        elif use_td == "unit":
            targeting_count = targeting_rate * filters
            targeting_fn = common_layers.unit_targeting
        else:
            raise Exception("Unrecognized targeted dropout type: %s" % use_td)

        y = common_layers.td_conv(
            inputs,
            filters,
            kernel_size,
            targeting_count,
            targeting_fn,
            keep_prob,
            is_training,
            do_prune=True,
            strides=strides,
            padding=("SAME" if strides == 1 else "VALID"),
            data_format=data_format,
            use_bias=False,
            kernel_initializer=tf.variance_scaling_initializer())
    else:
        y = tf.layers.conv2d(
            inputs=inputs,
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=("SAME" if strides == 1 else "VALID"),
            use_bias=False,
            kernel_initializer=tf.variance_scaling_initializer(),
            data_format=data_format)

    return y
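
A minimal sketch (an assumed stand-in for the fixed_padding helper referenced above, here in NHWC layout) of the explicit padding rule: the amount of padding depends only on kernel_size, and with stride > 1 the convolution then uses VALID padding.

import numpy as np

def fixed_padding_nhwc(inputs, kernel_size):
  """Pads height/width of a [batch, height, width, channels] array."""
  pad_total = kernel_size - 1
  pad_beg = pad_total // 2
  pad_end = pad_total - pad_beg
  return np.pad(inputs,
                [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]],
                mode="constant")

x = np.ones((1, 7, 7, 3))
print(fixed_padding_nhwc(x, kernel_size=3).shape)  # (1, 9, 9, 3)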
Ejemplo n.º 38
0
 def targets_bottom(self, x):
     with tf.variable_scope(self.name):
         return tf.zeros([
             common_layers.shape_list(x)[0], 1, 1,
             self._model_hparams.hidden_size
         ])
Ejemplo n.º 39
0
    def bottom(self, x):
        """Use batchnorm instead of CMVN and shorten the stft with strided convs.

    Args:
      x: float32 tensor with shape [batch_size, len, 1, freqs * channels]

    Returns:
      float32 tensor with shape [batch_size, shorter_len, 1, hidden_size]
    """
        inputs = x
        p = self._model_hparams

        num_mel_bins = p.audio_num_mel_bins
        num_channels = 3 if p.audio_add_delta_deltas else 1

        with tf.variable_scope(self.name):
            if p.audio_preproc_in_bottom:
                # Compute filterbanks
                with tf.variable_scope("fbanks"):
                    waveforms = tf.squeeze(inputs, [2, 3])
                    mel_fbanks = common_audio.compute_mel_filterbank_features(
                        waveforms,
                        sample_rate=p.audio_sample_rate,
                        dither=p.audio_dither,
                        preemphasis=p.audio_preemphasis,
                        frame_length=p.audio_frame_length,
                        frame_step=p.audio_frame_step,
                        lower_edge_hertz=p.audio_lower_edge_hertz,
                        upper_edge_hertz=p.audio_upper_edge_hertz,
                        num_mel_bins=p.audio_num_mel_bins,
                        apply_mask=True)
                    if p.audio_add_delta_deltas:
                        mel_fbanks = common_audio.add_delta_deltas(mel_fbanks)
                    x = tf.reshape(
                        mel_fbanks,
                        common_layers.shape_list(mel_fbanks)[:2] +
                        [num_mel_bins, num_channels])

                    nonpadding_mask = 1. - common_attention.embedding_to_padding(
                        x)
                    num_of_nonpadding_elements = tf.reduce_sum(
                        nonpadding_mask) * num_mel_bins * num_channels

                    # This replaces CMVN estimation on data
                    var_epsilon = 1e-09
                    mean = tf.reduce_sum(x, axis=[
                        1
                    ], keepdims=True) / num_of_nonpadding_elements
                    variance = (
                        num_of_nonpadding_elements * mean**2. -
                        2. * mean * tf.reduce_sum(x, axis=[1], keepdims=True) +
                        tf.reduce_sum(x**2, axis=[1], keepdims=True)
                    ) / num_of_nonpadding_elements
                    x = (x - mean) * tf.rsqrt(variance +
                                              var_epsilon) * tf.expand_dims(
                                                  nonpadding_mask, -1)
            else:
                x = inputs

            # The convention is that the models are flattened along the spatial
            # dimensions; thus the speech preprocessor treats frequencies and
            # channels as image colors (last axis).
            x.set_shape([None, None, num_mel_bins, num_channels])

            # TODO(chorowski): how to specify bottom's hparams and avoid hardcoding?
            x = tf.pad(x, [[0, 0], [0, 8], [0, 0], [0, 0]])
            for _ in range(2):
                x = tf.layers.conv2d(x, 128, (3, 3), (2, 2), use_bias=False)
                x = common_layers.layer_norm(x)
                x = tf.nn.relu(x)

            xshape = common_layers.shape_list(x)
            # Apply a conv that removes all frequencies and, at the same time,
            # projects the output into the desired hidden_size.
            x = tf.pad(x, [[0, 0], [0, 2], [0, 0], [0, 0]])
            x = tf.layers.conv2d(x,
                                 p.hidden_size, (3, xshape[2]),
                                 use_bias=False)

            assert common_layers.shape_list(x)[2] == 1
            x = common_layers.layer_norm(x)
            x = tf.nn.relu(x)
        return x
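
A small NumPy sketch of the masked mean/variance normalization above, which stands in for CMVN: statistics are taken only over non-padded time steps. This per-utterance variant and its names are illustrative, not the exact t2t reductions.

import numpy as np

def masked_normalize(x, nonpadding_mask, eps=1e-9):
  """x: [batch, time, bins, channels]; nonpadding_mask: [batch, time] of 0/1."""
  mask = nonpadding_mask[:, :, None, None]
  count = mask.sum(axis=1, keepdims=True) + eps
  mean = (x * mask).sum(axis=1, keepdims=True) / count
  var = (((x - mean) ** 2) * mask).sum(axis=1, keepdims=True) / count
  return (x - mean) / np.sqrt(var + eps) * mask

feats = np.random.rand(2, 5, 4, 1)
mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=np.float64)
normed = masked_normalize(feats, mask)
assert normed.shape == feats.shape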
Ejemplo n.º 40
0
    def __process(self, all_frames, all_actions, all_rewards, all_raw_frames):
        """Main video processing function."""
        hparams = self.hparams
        all_frames_copy = [tf.identity(frame) for frame in all_frames]
        orig_frame_shape = common_layers.shape_list(all_frames[0])
        batch_size = orig_frame_shape[0]
        ss_func = self.get_scheduled_sample_func(batch_size)
        target_frames = []
        extra_loss = 0.0

        # Any extra info required by the model goes in here.
        video_features = self.video_features(all_frames, all_actions,
                                             all_rewards, all_raw_frames)

        num_frames = len(all_frames)
        if self.is_recurrent_model:
            input_index_range = range(num_frames - 1)
        else:
            input_index_range = range(hparams.video_num_target_frames)

        # Set up the internal states as well as an auxiliary tf op
        # to enforce synchronization between prediction steps.
        if self.internal_states is None:
            internal_states = None
            sync_op = tf.no_op()
        else:
            internal_states = self.load_internal_states_ops()
            with tf.control_dependencies(flat_lists(internal_states)):
                sync_op = tf.no_op()

        res_frames, sampled_frames, res_rewards = [], [], []
        for i in input_index_range:
            with tf.control_dependencies([sync_op]):
                frames, actions, rewards, target_index = self.__get_next_inputs(
                    i, all_frames, all_actions, all_rewards)
                target_frame = all_frames[target_index]
                target_frames.append(tf.identity(target_frame))

                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=tf.AUTO_REUSE):
                    func_in = (frames, actions, rewards, target_frame,
                               internal_states, video_features)
                    func_out = self.next_frame(*func_in)
                    res_frame, res_reward, res_extra_loss, internal_states = func_out
                    res_frames.append(res_frame)
                    res_rewards.append(res_reward)
                    extra_loss += res_extra_loss / float(
                        len(input_index_range))

                    # Synchronizing the internal states.
                    # Some TensorFlow magic to make sure everything happens as it should.
                    with tf.control_dependencies([res_frame]):
                        sync_op = tf.no_op()
                        if self.is_predicting and self.is_recurrent_model and i == 0:
                            # The internal state save happens at the end of the 1st iteration,
                            # which essentially allows recurrent models to continue
                            # running after one prediction.
                            # Necessary for planning/RL applications.
                            save_ops = self.save_internal_states_ops(
                                internal_states)
                            with tf.control_dependencies(flat_lists(save_ops)):
                                sync_op = tf.no_op()

                # Only for Softmax loss: sample frame so we can keep iterating.
                sampled_frame = self.get_sampled_frame(res_frame)
                sampled_frames.append(sampled_frame)

                # Check whether we are done with context frames or not
                if self.is_recurrent_model:
                    done_warm_start = (i >= hparams.video_num_input_frames - 1)
                else:
                    done_warm_start = True  # Always true for non-recurrent networks.

                if self.is_predicting and done_warm_start:
                    all_frames[target_index] = sampled_frame

                # Scheduled sampling during training.
                if self.is_training:
                    groundtruth_items = [target_frame]
                    generated_items = [sampled_frame]
                    ss_frame, = self.get_scheduled_sample_inputs(
                        done_warm_start, groundtruth_items, generated_items,
                        ss_func)
                    all_frames[target_index] = ss_frame

        video_extra_loss = self.video_extra_loss(sampled_frames, target_frames,
                                                 internal_states,
                                                 video_features)
        tf.summary.scalar("video_extra_loss", video_extra_loss)
        extra_loss += video_extra_loss

        if self.is_recurrent_model:
            has_input_predictions = hparams.video_num_input_frames > 1
            if self.is_training and hparams.internal_loss and has_input_predictions:
                # add the loss for input frames as well.
                extra_gts = all_frames_copy[1:hparams.video_num_input_frames]
                extra_raw_gts = all_raw_frames[1:hparams.
                                               video_num_input_frames]
                extra_pds = res_frames[:hparams.video_num_input_frames - 1]
                recon_loss = self.get_extra_internal_loss(
                    extra_raw_gts, extra_gts, extra_pds)
                extra_loss += recon_loss
            # Cut the predicted input frames.
            res_frames = res_frames[hparams.video_num_input_frames - 1:]
            res_rewards = res_rewards[hparams.video_num_input_frames - 1:]
            sampled_frames = sampled_frames[hparams.video_num_input_frames -
                                            1:]
            target_frames = target_frames[hparams.video_num_input_frames - 1:]

        self.visualize_predictions(sampled_frames, target_frames)

        output_frames = tf.stack(res_frames, axis=1)
        targets = output_frames

        if self.has_rewards:
            output_rewards = tf.stack(res_rewards, axis=1)
            targets = {
                "targets": output_frames,
                "target_reward": output_rewards
            }

        return targets, extra_loss
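
A toy, framework-free sketch of the rollout loop above: ground-truth frames feed the model during the warm-start (input) phase and, when predicting, the model's own samples are fed back afterwards. predict_fn and sample_fn are placeholders, not the model's actual methods.

def rollout(frames, num_input_frames, predict_fn, sample_fn, predicting):
  frames = list(frames)
  predictions = []
  for i in range(len(frames) - 1):
    pred = predict_fn(frames[i])
    predictions.append(pred)
    done_warm_start = i >= num_input_frames - 1
    if predicting and done_warm_start:
      # Replace the next input with the model's own sample.
      frames[i + 1] = sample_fn(pred)
  # Keep only the predictions for target frames.
  return predictions[num_input_frames - 1:]

preds = rollout([0.0, 1.0, 2.0, 3.0], num_input_frames=2,
                predict_fn=lambda x: x + 1.0, sample_fn=lambda p: p,
                predicting=True)
print(preds)  # Predictions for the target frames.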
Ejemplo n.º 41
0
    def infer(self, features, *args, **kwargs):  # pylint: disable=arguments-differ
        """Produce predictions from the model by running it."""
        del args, kwargs
        # Inputs and features preparation needed to handle edge cases.
        if not features:
            features = {}
        hparams = self.hparams
        inputs_old = None
        if "inputs" in features and len(features["inputs"].shape) < 4:
            inputs_old = features["inputs"]
            features["inputs"] = tf.expand_dims(features["inputs"], 2)

        def logits_to_samples(logits):
            """Get samples from logits."""
            # If the last dimension is 1 then we're using L1/L2 loss.
            if common_layers.shape_list(logits)[-1] == 1:
                return tf.to_int32(tf.squeeze(logits, axis=-1))
            # Argmax in TF doesn't handle more than 5 dimensions yet.
            logits_shape = common_layers.shape_list(logits)
            argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]),
                               axis=-1)
            return tf.reshape(argmax, logits_shape[:-1])

        # Get predictions.
        try:
            num_channels = hparams.problem.num_channels
        except AttributeError:
            num_channels = 1
        if "inputs" in features:
            inputs_shape = common_layers.shape_list(features["inputs"])
            targets_shape = [
                inputs_shape[0], hparams.video_num_target_frames,
                inputs_shape[2], inputs_shape[3], num_channels
            ]
        else:
            tf.logging.warn("Guessing targets shape as no inputs are given.")
            targets_shape = [
                hparams.batch_size, hparams.video_num_target_frames, 1, 1,
                num_channels
            ]

        features["targets"] = tf.zeros(targets_shape, dtype=tf.int32)
        reward_in_mod = "target_reward" in hparams.problem_hparams.modality
        action_in_mod = "target_action" in hparams.problem_hparams.modality
        if reward_in_mod:
            # TODO(lukaszkaiser): this is a hack. get the actual reward history.
            if "input_reward" not in features:
                features["input_reward"] = tf.zeros(
                    [inputs_shape[0], inputs_shape[1], 1], dtype=tf.int32)
            features["target_reward"] = tf.zeros(
                [targets_shape[0], targets_shape[1], 1], dtype=tf.int32)
        if action_in_mod and "target_action" not in features:
            features["target_action"] = tf.zeros(
                [targets_shape[0], targets_shape[1], 1], dtype=tf.int32)
        logits, _ = self(features)  # pylint: disable=not-callable
        if isinstance(logits, dict):
            results = {}
            for k, v in six.iteritems(logits):
                results[k] = logits_to_samples(v)
                results["%s_logits" % k] = v
            # HACK: bypassing decoding issues.
            results["outputs"] = results["targets"]
            results["scores"] = results["targets"]
        else:
            results = logits_to_samples(logits)

        # Restore inputs to not confuse Estimator in edge cases.
        if inputs_old is not None:
            features["inputs"] = inputs_old

        # Return results.
        return results
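
A NumPy sketch of logits_to_samples above: when the logits carry a final class axis, flatten all leading dimensions, take the argmax, and reshape back; this mirrors the reshape trick used because tf.argmax historically handled at most 5 dimensions. Shapes are illustrative.

import numpy as np

def logits_to_samples_np(logits):
  if logits.shape[-1] == 1:  # L1/L2 loss: logits are the values themselves.
    return np.squeeze(logits, axis=-1).astype(np.int32)
  flat = logits.reshape(-1, logits.shape[-1])
  return np.argmax(flat, axis=-1).reshape(logits.shape[:-1])

video_logits = np.random.randn(2, 4, 8, 8, 3, 256)  # [B, T, H, W, C, vocab]
samples = logits_to_samples_np(video_logits)
assert samples.shape == (2, 4, 8, 8, 3)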
Ejemplo n.º 42
0
 def targets_bottom(self, x):
     with tf.variable_scope(self.name):
         return tf.zeros(
             [common_layers.shape_list(x)[0], 1, 1, self._body_input_depth])
Ejemplo n.º 43
0
    def construct_model(self, images, actions, rewards):
        """Builds the stochastic model.

    The model first encodes all the images (x_t) in the sequence
    using the encoder. Let's call the output e_t. Then it predicts the
    latent state of the next frame using a recurrent posterior network
    z ~ q(z|e_{0:t}) = N(mu(e_{0:t}), sigma(e_{0:t})).
    Another recurrent network predicts the embedding of the next frame
    using the approximated posterior e_{t+1} = p(e_{t+1}|e_{0:t}, z)
    Finally, the decoder decodes e_{t+1} into x_{t+1}.
    Skip connections from encoder to decoder help with reconstruction.

    Args:
      images: tensor of ground truth image sequences
      actions: list of action tensors (NOT used)
      rewards: list of reward tensors (NOT used)

    Returns:
      gen_images: generated images
      fake_rewards: input rewards, returned as the reward prediction!
      pred_mu: predicted means of the posterior
      pred_logvar: predicted log(var) of the posterior
    """
        # This model does not support action-conditioned or reward prediction.
        fake_reward_prediction = rewards
        del actions, rewards

        z_dim = self.hparams.z_dim
        g_dim = self.hparams.g_dim
        rnn_size = self.hparams.rnn_size
        posterior_rnn_layers = self.hparams.posterior_rnn_layers
        predictor_rnn_layers = self.hparams.predictor_rnn_layers
        context_frames = self.hparams.video_num_input_frames

        seq_len, batch_size, _, _, color_channels = common_layers.shape_list(
            images)

        # LSTM initial states.
        predictor_states = [None] * predictor_rnn_layers
        posterior_states = [None] * posterior_rnn_layers

        tf.logging.info(">>>> Encoding")
        # Encoding:
        enc_images, enc_skips = [], []
        images = tf.unstack(images, axis=0)
        for i, image in enumerate(images):
            with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
                enc, skips = self.encoder(image, rnn_size)
                enc = tfcl.flatten(enc)
                enc_images.append(enc)
                enc_skips.append(skips)

        tf.logging.info(">>>> Prediction")
        # Prediction
        pred_enc, pred_mu, pred_logvar = [], [], []
        for i in range(1, seq_len):
            with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
                # current encoding
                h_current = enc_images[i - 1]
                # target encoding
                h_target = enc_images[i]

                z = tf.random_normal([batch_size, z_dim],
                                     0,
                                     1,
                                     dtype=tf.float32)
                mu, logvar = tf.zeros_like(z), tf.zeros_like(z)

                # Only use the posterior at training time.
                if self.hparams.mode == tf.estimator.ModeKeys.TRAIN:
                    mu, logvar, posterior_states = self.lstm_gaussian(
                        h_target, posterior_states, rnn_size, z_dim,
                        posterior_rnn_layers)

                    # The original implementation has a multiplier of 0.5,
                    # removed here for simplicity, i.e. replacing var with std.
                    z = z * tf.exp(logvar) + mu

                # Predict output encoding
                h_pred, predictor_states = self.stacked_lstm(
                    tf.concat([h_current, z], axis=1), predictor_states,
                    rnn_size, g_dim, predictor_rnn_layers)

                pred_enc.append(h_pred)
                pred_mu.append(mu)
                pred_logvar.append(logvar)

        tf.logging.info(">>>> Decoding")
        # Decoding
        gen_images = []
        for i in range(seq_len - 1):
            with tf.variable_scope("decoding", reuse=tf.AUTO_REUSE):
                # use skip values of last available frame
                skip_index = min(context_frames - 1, i)

                h_pred = tf.reshape(pred_enc[i], [batch_size, 1, 1, g_dim])
                x_pred = self.decoder(h_pred, enc_skips[skip_index],
                                      color_channels)
                gen_images.append(x_pred)

        tf.logging.info(">>>> Done")
        gen_images = tf.stack(gen_images, axis=0)
        return gen_images, fake_reward_prediction, pred_mu, pred_logvar
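
A small NumPy sketch of the posterior sampling step above. As the code comment notes, exp(logvar) is treated directly as the standard deviation (the usual 0.5 factor is dropped), i.e. z = mu + exp(logvar) * eps. Names are illustrative.

import numpy as np

def sample_latent(mu, logvar, rng=np.random):
  eps = rng.standard_normal(mu.shape)
  return mu + np.exp(logvar) * eps

mu = np.zeros((4, 16))
logvar = np.full((4, 16), -1.0)
z = sample_latent(mu, logvar)
print(z.shape)  # (4, 16)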
Ejemplo n.º 44
0
  def body(self, features):
    hparams = self.hparams
    batch_size = common_layers.shape_list(features["inputs"])[0]

    # Swap time and batch axes.
    input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
    target_frames = common_video.swap_time_and_batch_axes(features["targets"])

    # Get actions if they exist, otherwise use zeros.
    input_actions = self.get_input_if_exists(
        features, "input_action", batch_size, hparams.video_num_input_frames)
    target_actions = self.get_input_if_exists(
        features, "target_action", batch_size, hparams.video_num_target_frames)

    # Get rewards if they exist, otherwise use zeros.
    input_rewards = self.get_input_if_exists(
        features, "input_reward", batch_size, hparams.video_num_input_frames)
    target_rewards = self.get_input_if_exists(
        features, "target_reward", batch_size, hparams.video_num_target_frames)

    all_actions = tf.concat([input_actions, target_actions], axis=0)
    all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
    all_frames = tf.concat([input_frames, target_frames], axis=0)

    # Each image is being used twice, in latent tower and main tower.
    # This is to make sure we are using the *same* image for both, ...
    # ... given how TF queues work.
    # NOT sure if this is required at all. Doesn't hurt though! :)
    all_frames = tf.identity(all_frames)

    gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
        images=all_frames,
        actions=all_actions,
        rewards=all_rewards,
    )

    extra_loss = self.get_extra_loss(
        latent_means=latent_means,
        latent_stds=latent_stds,
        true_frames=all_frames,
        gen_frames=gen_images)

    # Visualize predictions in TensorBoard.
    if self.is_training:
      self.visualize_predictions(all_frames[1:], gen_images)

    # Ignore the predictions from the input frames.
    # This is NOT the same as the original paper/implementation.
    predictions = gen_images[hparams.video_num_input_frames-1:]
    reward_pred = gen_rewards[hparams.video_num_input_frames-1:]
    reward_pred = tf.squeeze(reward_pred, axis=2)  # Remove extra dimension.

    # Swap back time and batch axes.
    predictions = common_video.swap_time_and_batch_axes(predictions)
    reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

    if self.is_training and hparams.internal_loss:
      # add the loss for input frames as well.
      extra_gts = all_frames[1:hparams.video_num_input_frames]
      extra_gts = common_video.swap_time_and_batch_axes(extra_gts)
      extra_pds = gen_images[:hparams.video_num_input_frames-1]
      extra_pds = common_video.swap_time_and_batch_axes(extra_pds)
      extra_raw_gts = features["inputs_raw"][:, 1:]
      recon_loss = self.get_extra_internal_loss(
          extra_raw_gts, extra_gts, extra_pds)
      extra_loss += recon_loss

    return_targets = predictions
    if hparams.reward_prediction:
      return_targets = {"targets": predictions, "target_reward": reward_pred}

    return return_targets, extra_loss
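
A minimal NumPy sketch of swapping the time and batch axes, as done above for frames of shape [batch, time, height, width, channels]; the common_video helper is essentially a transpose of the first two axes.

import numpy as np

def swap_time_and_batch_axes_np(x):
  perm = [1, 0] + list(range(2, x.ndim))
  return np.transpose(x, perm)

frames = np.zeros((8, 5, 64, 64, 3))  # [batch, time, H, W, C]
time_major = swap_time_and_batch_axes_np(frames)
assert time_major.shape == (5, 8, 64, 64, 3)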
Ejemplo n.º 45
0
  def construct_model(self,
                      images,
                      actions,
                      rewards):
    """Build convolutional lstm video predictor using CDNA, or DNA.

    Args:
      images: list of tensors of ground truth image sequences
              there should be a 4D image ?xWxHxC for each timestep
      actions: list of action tensors
               each action should be in the shape ?x1xZ
      rewards: list of reward tensors
               each reward should be in the shape ?x1xZ
    Returns:
      gen_images: predicted future image frames
      gen_rewards: predicted future rewards
      latent_mean: mean of approximated posterior
      latent_std: std of approximated posterior

    Raises:
      ValueError: if more than 1 mask specified for DNA model.
    """
    context_frames = self.hparams.video_num_input_frames
    buffer_size = self.hparams.reward_prediction_buffer_size
    if buffer_size == 0:
      buffer_size = context_frames
    if buffer_size > context_frames:
      raise ValueError("Buffer size is bigger than context frames %d %d." %
                       (buffer_size, context_frames))

    batch_size = common_layers.shape_list(images[0])[0]
    ss_func = self.get_scheduled_sample_func(batch_size)

    def process_single_frame(prev_outputs, inputs):
      """Process a single frame of the video."""
      cur_image, input_reward, action = inputs
      time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs

      # Sample from softmax (by argmax). This is a no-op for non-softmax loss.
      prev_image = self.get_sampled_frame(prev_image)

      generated_items = [prev_image]
      groundtruth_items = [cur_image]
      done_warm_start = tf.greater(time_step, context_frames - 1)
      input_image, = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)

      # Prediction
      pred_image, lstm_states, _ = self.construct_predictive_tower(
          input_image, None, action, lstm_states, latent)

      if self.hparams.reward_prediction:
        reward_input_image = self.get_sampled_frame(pred_image)
        if self.hparams.reward_prediction_stop_gradient:
          reward_input_image = tf.stop_gradient(reward_input_image)
        with tf.control_dependencies([time_step]):
          frame_buf = [reward_input_image] + frame_buf[:-1]
        pred_reward = self.reward_prediction(frame_buf, None, action, latent)
        pred_reward = common_video.decode_to_shape(
            pred_reward, common_layers.shape_list(input_reward), "reward_dec")
      else:
        pred_reward = prev_reward

      time_step += 1
      outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states)

      return outputs

    # Latent tower
    latent = None
    if self.hparams.stochastic_model:
      latent_mean, latent_std = self.construct_latent_tower(images, time_axis=0)
      latent = common_video.get_gaussian_tensor(latent_mean, latent_std)

    # HACK: Do first step outside to initialize all the variables

    lstm_states = [None] * (5 if self.hparams.small_mode else 7)
    frame_buffer = [tf.zeros_like(images[0])] * buffer_size
    inputs = images[0], rewards[0], actions[0]
    init_image_shape = common_layers.shape_list(images[0])
    if self.is_per_pixel_softmax:
      init_image_shape[-1] *= 256
    init_image = tf.zeros(init_image_shape, dtype=images.dtype)
    prev_outputs = (tf.constant(0),
                    init_image,
                    tf.zeros_like(rewards[0]),
                    frame_buffer,
                    lstm_states)

    initializers = process_single_frame(prev_outputs, inputs)
    first_gen_images = tf.expand_dims(initializers[1], axis=0)
    first_gen_rewards = tf.expand_dims(initializers[2], axis=0)

    inputs = (images[1:-1], rewards[1:-1], actions[1:-1])

    outputs = tf.scan(process_single_frame, inputs, initializers)
    gen_images, gen_rewards = outputs[1:3]

    gen_images = tf.concat((first_gen_images, gen_images), axis=0)
    gen_rewards = tf.concat((first_gen_rewards, gen_rewards), axis=0)

    if self.hparams.stochastic_model:
      return gen_images, gen_rewards, [latent_mean], [latent_std]
    else:
      return gen_images, gen_rewards, None, None
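
A framework-free sketch of the rollout pattern above: the step function is applied once outside the loop (which, in the TF-1 graph, creates all the variables), then folded over the remaining inputs while threading its state, and the first output is concatenated back on. step_fn stands in for process_single_frame.

def scan_rollout(step_fn, init_state, inputs):
  # First step outside the loop, analogous to the variable-creating call above.
  state = step_fn(init_state, inputs[0])
  outputs = [state]
  # Remaining steps, analogous to tf.scan over inputs[1:].
  for inp in inputs[1:]:
    state = step_fn(state, inp)
    outputs.append(state)
  return outputs

states = scan_rollout(lambda s, x: s + x, 0, [1, 2, 3, 4])
print(states)  # [1, 3, 6, 10]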
Ejemplo n.º 46
0
  def next_frame(self, frames, actions, rewards, target_frame,
                 internal_states, video_extra):
    del rewards, video_extra

    hparams = self.hparams
    filters = hparams.hidden_size
    kernel2 = (4, 4)
    action = actions[-1]

    # Stack the inputs.
    if internal_states is not None and hparams.concat_internal_states:
      # Use the first part of the first internal state if asked to concatenate.
      batch_size = common_layers.shape_list(frames[0])[0]
      internal_state = internal_states[0][0][:batch_size, :, :, :]
      stacked_frames = tf.concat(frames + [internal_state], axis=-1)
    else:
      stacked_frames = tf.concat(frames, axis=-1)
    inputs_shape = common_layers.shape_list(stacked_frames)

    # Update internal states early if requested.
    if hparams.concat_internal_states:
      internal_states = self.update_internal_states_early(
          internal_states, frames)

    # Using non-zero bias initializer below for edge cases of uniform inputs.
    x = tf.layers.dense(
        stacked_frames, filters, name="inputs_embed",
        bias_initializer=tf.random_normal_initializer(stddev=0.01))
    x = common_attention.add_timing_signal_nd(x)

    # Down-stride.
    layer_inputs = [x]
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("downstride%d" % i):
        layer_inputs.append(x)
        x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
        x = common_layers.make_even_size(x)
        if i < hparams.filter_double_steps:
          filters *= 2
        x = common_attention.add_timing_signal_nd(x)
        x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu,
                             strides=(2, 2), padding="SAME")
        x = common_layers.layer_norm(x)

    if self.has_actions:
      with tf.variable_scope("policy"):
        x_flat = tf.layers.flatten(x)
        policy_pred = tf.layers.dense(x_flat, self.hparams.problem.num_actions)
        value_pred = tf.layers.dense(x_flat, 1)
        value_pred = tf.squeeze(value_pred, axis=-1)
    else:
      policy_pred, value_pred = None, None

    # Add embedded action if present.
    if self.has_actions:
      x = common_video.inject_additional_input(
          x, action, "action_enc", hparams.action_injection)

    # Inject latent if present. Only for stochastic models.
    x, extra_loss = self.inject_latent(x, frames, target_frame, action)

    x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    x, internal_states = self.middle_network(x, internal_states)

    # Up-convolve.
    layer_inputs = list(reversed(layer_inputs))
    for i in range(hparams.num_compress_steps):
      with tf.variable_scope("upstride%d" % i):
        x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
        if self.has_actions:
          x = common_video.inject_additional_input(
              x, action, "action_enc", hparams.action_injection)
        if i >= hparams.num_compress_steps - hparams.filter_double_steps:
          filters //= 2
        x = tf.layers.conv2d_transpose(
            x, filters, kernel2, activation=common_layers.belu,
            strides=(2, 2), padding="SAME")
        y = layer_inputs[i]
        shape = common_layers.shape_list(y)
        x = x[:, :shape[1], :shape[2], :]
        x = common_layers.layer_norm(x + y)
        x = common_attention.add_timing_signal_nd(x)

    # Cut down to original size.
    x = x[:, :inputs_shape[1], :inputs_shape[2], :]
    x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    if self.is_per_pixel_softmax:
      x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
    else:
      x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

    reward_pred = None
    if self.has_rewards:
      # Reward prediction based on middle and final logits.
      reward_pred = tf.concat([x_mid, x_fin], axis=-1)
      reward_pred = tf.nn.relu(tf.layers.dense(
          reward_pred, 128, name="reward_pred"))
      reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
      reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

    return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
Ejemplo n.º 47
0
def lstm_attention_search_based_decoder(inputs, hparams, train, name,
                                        initial_state, encoder_outputs,
                                        build_storage, storage, n):
    """Run LSTM cell with attention on inputs of shape [batch x time x size]."""
    def dropout_lstm_cell():
        return tf.contrib.rnn.DropoutWrapper(
            LSTMShallowFusionCell(hparams.hidden_size, build_storage, storage),
            input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))

    layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)]
    if hparams.attention_mechanism == "luong":
        attention_mechanism_class = tf.contrib.seq2seq.LuongAttention
    elif hparams.attention_mechanism == "bahdanau":
        attention_mechanism_class = tf.contrib.seq2seq.BahdanauAttention
    else:
        raise ValueError("Unknown hparams.attention_mechanism = %s, must be "
                         "luong or bahdanau." % hparams.attention_mechanism)
    attention_mechanism = attention_mechanism_class(hparams.hidden_size,
                                                    encoder_outputs)

    if not build_storage:
        p_copy = [
            tf.TensorArray(tf.float32,
                           size=tf.shape(inputs)[1],
                           dynamic_size=True,
                           name='dzeta_dot_q'),
            tf.TensorArray(tf.float32,
                           size=tf.shape(inputs)[1],
                           dynamic_size=True,
                           name='1_dzeta')
        ]
    else:
        p_copy = None
    # TODO: add fusion_type in hparams
    cell = AttentionWrapperSearchBased(
        tf.nn.rnn_cell.MultiRNNCell(layers),
        [attention_mechanism] * hparams.num_heads,
        storage=storage,
        build_storage=build_storage,
        p_copy=p_copy,
        start_index=n,
        attention_layer_size=[hparams.attention_layer_size] *
        hparams.num_heads,
        output_attention=(hparams.output_attention == 1))

    batch_size = common_layers.shape_list(inputs)[0]

    initial_state = cell.zero_state(batch_size,
                                    tf.float32).clone(cell_state=initial_state)

    with tf.variable_scope(name):
        output, state = tf.nn.dynamic_rnn(cell,
                                          inputs,
                                          initial_state=initial_state,
                                          dtype=tf.float32,
                                          time_major=False)

        # For multi-head attention, project the output back to hidden size.
        if hparams.output_attention == 1 and hparams.num_heads > 1:
            output = tf.layers.dense(output, hparams.hidden_size)

        return output, p_copy
Ejemplo n.º 48
0
    def body(self, features):
        hp = self.hparams
        # pylint: disable=eval-used
        if hp.image_input_type == "image":
            image_feat = vqa_layers.image_embedding(
                features["inputs"],
                model_fn=eval(hp.image_model_fn),
                trainable=hp.train_resnet,
                is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
        else:
            image_feat = features["inputs"]

        image_feat = common_layers.flatten4d3d(image_feat)
        image_feat = common_layers.dense(image_feat, hp.hidden_size)
        utils.collect_named_outputs("norms", "image_feat_after_proj",
                                    tf.norm(image_feat, axis=-1))

        question = common_layers.flatten4d3d(features["question"])
        utils.collect_named_outputs("norms", "question_embedding",
                                    tf.norm(question, axis=-1))
        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = prepare_image_question_encoder(
             image_feat, question, hp)

        encoder_input = tf.nn.dropout(encoder_input,
                                      keep_prob=1. -
                                      hp.layer_prepostprocess_dropout)

        encoder_output, _ = recurrent_transformer_decoder(
            encoder_input,
            None,
            encoder_self_attention_bias,
            None,
            hp,
            name="encoder")
        utils.collect_named_outputs("norms", "encoder_output",
                                    tf.norm(encoder_output, axis=-1))

        # scale query by sqrt(hidden_size)
        query = tf.get_variable("query",
                                [hp.hidden_size]) * hp.hidden_size**0.5
        query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0)
        batch_size = common_layers.shape_list(encoder_input)[0]
        query = tf.tile(query, [batch_size, 1, 1])
        query = tf.nn.dropout(query,
                              keep_prob=1. - hp.layer_prepostprocess_dropout)

        decoder_output, _ = recurrent_transformer_decoder(
            query,
            encoder_output,
            None,
            encoder_decoder_attention_bias,
            hp,
            name="decoder")
        utils.collect_named_outputs("norms", "decoder_output",
                                    tf.norm(decoder_output, axis=-1))

        norm_tensors = utils.convert_collection_to_dict("norms")
        vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

        # Expand dimension 1 and 2
        return tf.expand_dims(decoder_output, axis=1)
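
A small NumPy sketch of the pooling query above: a learned vector of size hidden_size is scaled by sqrt(hidden_size), given leading batch/length dimensions, and tiled across the batch before attending over the encoder output. Sizes are illustrative.

import numpy as np

hidden_size, batch_size = 512, 4
query = np.random.randn(hidden_size) * hidden_size ** 0.5
query = query[None, None, :]                # [1, 1, hidden_size]
query = np.tile(query, [batch_size, 1, 1])  # [batch, 1, hidden_size]
assert query.shape == (batch_size, 1, hidden_size)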
Ejemplo n.º 49
0
def conv(name, x, output_channels, filter_size=None, stride=None,
         logscale_factor=3.0, apply_actnorm=True, conv_init="default",
         dilations=None):
  """Convolutional layer with edge bias padding and optional actnorm.

  If x is 5-dimensional, actnorm is applied independently across every
  time-step.

  Args:
    name: variable scope.
    x: 4-D Tensor or 5-D Tensor of shape NHWC or NTHWC
    output_channels: Number of output channels.
    filter_size: list of ints, if None [3, 3] and [2, 3, 3] are defaults for
                 4-D and 5-D input tensors respectively.
    stride: list of ints, defaults to 1 in every spatial dimension.
    logscale_factor: see actnorm for parameter meaning.
    apply_actnorm: if True, the activations of the first minibatch have
                   zero mean and unit variance; otherwise no scaling is
                   applied.
    conv_init: "default" or "zeros". "default" is a normal distribution with 0.05 std.
    dilations: List of integers, apply dilations.
  Returns:
    x: actnorm(conv2d(x))
  Raises:
    ValueError: if init is set to "zeros" and apply_actnorm is set to True.
  """
  if conv_init == "zeros" and apply_actnorm:
    raise ValueError("apply_actnorm is unstable when init is set to zeros.")

  x_shape = common_layers.shape_list(x)
  is_2d = len(x_shape) == 4
  num_steps = x_shape[1]

  # set filter_size, stride and in_channels
  if is_2d:
    if filter_size is None:
      filter_size = [3, 3]
    if stride is None:
      stride = [1, 1]
    if dilations is None:
      dilations = [1, 1, 1, 1]
    actnorm_func = actnorm
    x = add_edge_bias(x, filter_size=filter_size)
    conv_filter = tf.nn.conv2d
  else:
    if filter_size is None:
      if num_steps == 1:
        filter_size = [1, 3, 3]
      else:
        filter_size = [2, 3, 3]
    if stride is None:
      stride = [1, 1, 1]
    if dilations is None:
      dilations = [1, 1, 1, 1, 1]
    actnorm_func = actnorm_3d
    x = time_pad(x, filter_size=filter_size, dilations=dilations)
    conv_filter = tf.nn.conv3d

  in_channels = common_layers.shape_list(x)[-1]
  filter_shape = filter_size + [in_channels, output_channels]
  stride_shape = [1] + stride + [1]

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):

    if conv_init == "default":
      initializer = default_initializer()
    elif conv_init == "zeros":
      initializer = tf.zeros_initializer()

    w = tf.get_variable("W", filter_shape, tf.float32, initializer=initializer)
    x = conv_filter(x, w, stride_shape, padding="VALID", dilations=dilations)
    if apply_actnorm:
      x, _ = actnorm_func("actnorm", x, logscale_factor=logscale_factor)
    else:
      x += tf.get_variable("b", [1, 1, 1, output_channels],
                           initializer=tf.zeros_initializer())
      logs = tf.get_variable("logs", [1, output_channels],
                             initializer=tf.zeros_initializer())
      x *= tf.exp(logs * logscale_factor)
    return x
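
A tiny sketch of the shape bookkeeping in conv() above: the kernel shape is filter_size + [in_channels, output_channels], and the stride gets batch and channel dimensions of 1 appended, for both the 4-D (NHWC) and 5-D (NTHWC) cases.

def conv_shapes(filter_size, in_channels, output_channels, stride):
  filter_shape = list(filter_size) + [in_channels, output_channels]
  stride_shape = [1] + list(stride) + [1]
  return filter_shape, stride_shape

print(conv_shapes([3, 3], 16, 32, [1, 1]))        # ([3, 3, 16, 32], [1, 1, 1, 1])
print(conv_shapes([2, 3, 3], 16, 32, [1, 1, 1]))  # ([2, 3, 3, 16, 32], [1, 1, 1, 1, 1])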
Ejemplo n.º 50
0
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    # flatten here
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
          x=targets_c,
          filter_size=hparams.compress_filter_size,
          name="vc",
          mode=hparams.mode)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex,
            embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn
        inputs_c = bn_inputs
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c, filter_size=hparams.compress_filter_size, name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # decompressing the dense latents
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth]
      targets = mask * targets + (1.0 - mask) * d
      # reshape back to 4d here
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)
    if hparams.task == "translate":
      targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder",
                           causal=hparams.causal)
  if hparams.do_ae:
    if hparams.task == "translate":
      res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(tf.squeeze(res, axis=[2]),
                      target_space, hparams, "refine_enc")
        return tf.expand_dims(r, axis=2)
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # We'll start training the extra model of latents after mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
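
A pure-Python sketch of the masking schedule above, with assumed stand-ins for the t2t decay helpers (both ramp from roughly 0 up to 1 by max_step, one linearly and one geometrically; these formulas are illustrative, not the library's implementations). Their product, clipped to [0, 1], is the probability that a target position is replaced by the decompressed latent d.

def inverse_lin_decay(step, max_step):
  return min(step / float(max_step), 1.0)

def inverse_exp_decay(step, max_step, min_value=0.01):
  return min_value ** max(1.0 - step / float(max_step), 0.0)

def masking_fraction(step, mask_startup_steps):
  m = inverse_lin_decay(step, mask_startup_steps)
  m *= inverse_exp_decay(step, mask_startup_steps // 4)  # Not much at start.
  return min(max(m, 0.0), 1.0)

for s in [0, 1000, 10000, 50000]:
  print(s, round(masking_fraction(s, mask_startup_steps=50000), 4))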
Ejemplo n.º 51
0
    def render2cmd_v3_internal(self, features, hparams, train):
        # inputs and targets are both sequences with
        # shape = [batch, seq_len, 1, hparams.problem.feature_dim]
        targets = features['targets']
        source = features['source']

        losses = {}
        sampled_bottleneck = self.pretrained_visual_encoder(features, hparams)

        if hparams.sg_bottleneck:
            sampled_bottleneck = tf.stop_gradient(sampled_bottleneck)

        with tf.variable_scope('render2cmd_v3_internal'):
            # override bottleneck, or return it, if requested
            if 'bottleneck' in features:
                if common_layers.shape_list(features['bottleneck'])[0] == 0:
                    # Return sampled_bottleneck and set losses['training'] = 0
                    # so self.top() doesn't get called on it.
                    return sampled_bottleneck, {'training': 0.0}
                else:
                    # we want to use the given bottleneck
                    sampled_bottleneck = features['bottleneck']

            # finalize bottleneck
            unbottleneck_dim = hparams.hidden_size * 2  # twice because using LSTM
            if hparams.twice_decoder:
                unbottleneck_dim = unbottleneck_dim * 2

            dec_initial_state = []

            # LSTM encoder
            _, encoder_output_states = self.lstm_encoder(
                common_layers.flatten4d3d(source), hparams)
            for hi in range(hparams.num_hidden_layers):
                unbottleneck = self.unbottleneck(sampled_bottleneck, unbottleneck_dim,
                                                 name_append='_{}'.format(hi))
                c, h = encoder_output_states[hi]
                dec_initial_state.append(
                    tf.nn.rnn_cell.LSTMStateTuple(
                        c=tf.concat(
                            [unbottleneck[:, :unbottleneck_dim // 2], c], 1),
                        h=tf.concat([unbottleneck[:, unbottleneck_dim // 2:], h], 1)))

            dec_initial_state = tuple(dec_initial_state)
            shifted_targets = common_layers.shift_right(targets)
            # Add 1 to account for the padding added to the left from shift_right
            targets_length = common_layers.length_from_embedding(
                shifted_targets) + 1

            # LSTM decoder
            hparams_decoder = copy.copy(hparams)
            if hparams.twice_decoder:
                hparams_decoder.hidden_size = 2 * hparams.hidden_size

            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                decoder_outputs, _ = self.lstm_decoder_infer(
                    common_layers.flatten4d3d(shifted_targets),
                    targets_length, hparams_decoder, features['targets_cls'],
                    train, initial_state=dec_initial_state,
                    bottleneck=sampled_bottleneck)
            else:
                decoder_outputs, _ = self.lstm_decoder(
                    common_layers.flatten4d3d(shifted_targets),
                    targets_length, hparams_decoder, features['targets_cls'],
                    train, initial_state=dec_initial_state,
                    bottleneck=sampled_bottleneck)

            ret = tf.expand_dims(decoder_outputs, axis=2)
        return ret, losses
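
A NumPy sketch of how the decoder's initial state is assembled in the loop above: for each layer, the encoder's LSTM (c, h) is concatenated with the two halves of the un-bottlenecked vector to form the new state. Sizes and names are illustrative.

import numpy as np

def build_initial_state(unbottleneck, c, h):
  half = unbottleneck.shape[-1] // 2
  new_c = np.concatenate([unbottleneck[:, :half], c], axis=1)
  new_h = np.concatenate([unbottleneck[:, half:], h], axis=1)
  return new_c, new_h

ub = np.zeros((4, 2 * 128))  # Twice hidden_size, for the LSTM's c and h.
c = np.zeros((4, 128))
h = np.zeros((4, 128))
new_c, new_h = build_initial_state(ub, c, h)
assert new_c.shape == new_h.shape == (4, 256)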
Ejemplo n.º 52
0
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    encoder_layers = None
    self.is1d = hparams.sample_width == 1
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      labels = features["targets_raw"]
      labels_shape = common_layers.shape_list(labels)
      # handle videos
      if len(labels.shape) == 5:
        labels = common_layers.time_to_channels(labels)
      shape = common_layers.shape_list(labels)
      x = tf.one_hot(labels, vocab_size)
      x = self.embed(x)
      target_codes = x
      if shape[2] == 1:
        self.is1d = True
      # Run encoder.
      x, encoder_layers = self.encoder(x)
      # Bottleneck.
      b, b_loss = self.bottleneck(x)
      xb_loss = 0.0
      b_shape = common_layers.shape_list(b)
      self._cur_bottleneck_tensor = b
      b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
      if not is_training:
        x = b
      else:
        l = 2**hparams.num_hidden_layers
        warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
        nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_bottleneck", nomix_p)
        rand = tf.random_uniform(common_layers.shape_list(x))
        # This is the distance between b and x. Having this as loss helps learn
        # the bottleneck function, but if we back-propagated to x it would be
        # minimized by just setting x=0 and b=0 -- so we don't want too much
        # of the influence of this, and we stop-gradient to not zero-out x.
        x_stop = tf.stop_gradient(x)
        xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
        # To prevent this loss from exploding we clip at 1, but anneal clipping.
        clip_max = 1.0 / common_layers.inverse_exp_decay(
            warm_step, min_value=0.001)
        xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
        xb_loss *= clip_max / xb_clip
        x = tf.where(tf.less(rand, nomix_p), b, x)
      if hparams.gan_loss_factor != 0.0:
        # Add a purely sampled batch on which we'll compute the GAN loss.
        g = self.unbottleneck(
            self.sample(shape=b_shape),
            common_layers.shape_list(x)[-1],
            reuse=True)
        x = tf.concat([g, x], axis=0)
        encoder_layers = [tf.concat([l, l], axis=0) for l in encoder_layers]
    else:
      if self._cur_bottleneck_tensor is None:
        b = self.sample()
      else:
        b = self._cur_bottleneck_tensor
      self._cur_bottleneck_tensor = b
      res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
      res_size = min(res_size, hparams.max_hidden_size)
      x = self.unbottleneck(b, res_size)
    # Run decoder.
    x = self.decoder(x, encoder_layers)

    # Cut to the right size and mix before returning.
    res = x
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      res = x[:, :shape[1], :shape[2], :]

    # Final dense layer.
    res = tf.layers.dense(
        res, self.num_channels * hparams.hidden_size, name="res_dense")

    output_shape = common_layers.shape_list(res)[:-1] + [
        self.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hparams.use_vq_loss:
        (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
      else:
        reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      return reconstr, {"bottleneck_loss": 0.0}

    if hparams.gan_loss_factor != 0.0:
      res_gan, res = tf.split(res, 2, axis=0)

    # Losses.
    losses = {
        "bottleneck_extra": b_loss,
        "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
    }

    if hparams.use_vq_loss:
      vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.2,
          min_value=hparams.vq_temperature * 2)
      if hparams.mode != tf.estimator.ModeKeys.TRAIN:
        vq_temperature = None
      with tf.variable_scope("vq_loss"):
        (reconstr, _, target_codes, code_loss,
         targets_loss) = discretization.vq_loss(
             res, labels, vocab_size, temperature=vq_temperature)
      losses["code_loss"] = code_loss * hparams.code_loss_factor
      losses["training"] = targets_loss
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
      targets_loss = tf.losses.sparse_softmax_cross_entropy(
          logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
          labels=tf.reshape(labels, labels_shape))
      losses["training"] = targets_loss

    # GAN losses.
    if hparams.gan_loss_factor != 0.0:
      update_means_factor = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps, min_value=0.0001)
      if hparams.use_vq_loss:
        with tf.variable_scope("vq_loss", reuse=True):
          update_means = tf.less(tf.random_uniform([]), update_means_factor)
          reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
              res_gan,
              labels,
              vocab_size,
              do_update=update_means,
              temperature=vq_temperature)
          reconstr_gan_nonoise = reconstr_gan
          code_loss_gan *= hparams.code_loss_factor * update_means_factor
          losses["code_loss_gan"] = code_loss_gan
      else:
        reconstr_gan = tf.layers.dense(
            res_gan, vocab_size, name="autoencoder_final", reuse=True)
        reconstr_gan_nonoise = reconstr_gan
        reconstr_gan = self.gumbel_sample(reconstr_gan)
        # Embed to codes.
        gan_codes = self.embed(reconstr_gan)

    # Add GAN loss if requested.
    gan_loss = 0.0
    if hparams.gan_loss_factor != 0.0:
      self.image_summary("gan", reconstr_gan_nonoise)

      def discriminate(x):
        """Run a dioscriminator depending on the hparams."""
        if hparams.discriminator == "default":
          return common_layers.deep_discriminator(
              x, hparams.discriminator_batchnorm, is_training)
        elif hparams.discriminator == "patched":
          return common_layers.patch_discriminator(x)
        elif hparams.discriminator == "single":
          return common_layers.single_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        elif hparams.discriminator == "double":
          return common_layers.double_discriminator(
              x,
              hparams.discriminator_size,
              hparams.discriminator_kernel_size,
              hparams.discriminator_strides,
              pure_mean=hparams.discriminator_pure_mean)
        else:
          raise Exception("Unknown discriminator %s" % hparams.discriminator)

      tc_shape = common_layers.shape_list(target_codes)
      if len(tc_shape) > 4:
        target_codes = tf.reshape(target_codes,
                                  tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
        gan_codes = tf.reshape(gan_codes,
                               tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_lr = common_layers.inverse_exp_decay(
          hparams.gan_codes_warmup_steps * 1.5)
      rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
      gan_loss = common_layers.sliced_gan_loss(
          target_codes,
          rev_grad_gan_codes,
          discriminate,
          self.hparams.num_sliced_vecs,
          do_tanh=hparams.sliced_do_tanh)
      gan_loss *= hparams.gan_loss_factor * update_means_factor
      losses["gan_loss"] = -gan_loss

    self.image_summary("ae", reconstr)

    logits = tf.reshape(reconstr, labels_shape + [vocab_size])
    return logits, losses
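
Illustrative note: during training the code above mixes the encoder output x with its bottleneck reconstruction b elementwise, and adds a stop-gradient distance loss whose magnitude is clipped and annealed. A simplified, standalone restatement of just that step (assuming TF 1.x; mix_and_penalize is a hypothetical helper, and nomix_p / clip_max are passed in here rather than derived from the warm-up schedules):

import tensorflow as tf

def mix_and_penalize(x, b, nomix_p, clip_max):
  """Mixes x with its reconstruction b and returns a clipped distance loss."""
  x_stop = tf.stop_gradient(x)
  xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
  # Rescale so the loss never exceeds clip_max (the clip anneals over training).
  xb_loss *= clip_max / tf.maximum(tf.stop_gradient(xb_loss), clip_max)
  # Elementwise: use b with probability nomix_p, otherwise keep x.
  rand = tf.random_uniform(tf.shape(x))
  mixed = tf.where(tf.less(rand, nomix_p), b, x)
  return mixed, xb_loss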
Example No. 53
def invertible_1x1_conv(name, x, reverse=False):
  """1X1 convolution on x.

  The 1X1 convolution is parametrized as P*L*(U + sign(s)*exp(log(s))) where
  1. P is a permutation matrix.
  2. L is a lower triangular matrix with diagonal entries unity.
  3. U is a upper triangular matrix where the diagonal entries zero.
  4. s is a vector.

  sign(s) and P are fixed and the remaining are optimized. P, L, U and s are
  initialized by the PLU decomposition of a random rotation matrix.

  Args:
    name: scope
    x: Input Tensor.
    reverse: whether the pass is from z -> x or x -> z.

  Returns:
    x_conv: x after the 1x1 convolution is applied.
    objective: sum(log(s)) scaled by the number of spatial positions (height * width).
  """
  _, height, width, channels = common_layers.shape_list(x)
  w_shape = [channels, channels]

  # Random rotation-matrix Q
  random_matrix = np.random.rand(channels, channels)
  np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")

  # Initialize P,L,U and s from the LU decomposition of a random rotation matrix
  np_p, np_l, np_u = scipy.linalg.lu(np_w)
  np_s = np.diag(np_u)
  np_sign_s = np.sign(np_s)
  np_log_s = np.log(np.abs(np_s))
  np_u = np.triu(np_u, k=1)

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    p = tf.get_variable("P", initializer=np_p, trainable=False)
    l = tf.get_variable("L", initializer=np_l)
    sign_s = tf.get_variable(
        "sign_S", initializer=np_sign_s, trainable=False)
    log_s = tf.get_variable("log_S", initializer=np_log_s)
    u = tf.get_variable("U", initializer=np_u)

    # W = P * L * (U + sign_s * exp(log_s))
    l_mask = np.tril(np.ones([channels, channels], dtype=np.float32), -1)
    l = l * l_mask + tf.eye(channels, channels)
    u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
    w = tf.matmul(p, tf.matmul(l, u))

    # If height or width cannot be statically determined then they end up as
    # tf.int32 tensors, which cannot be directly multiplied with a floating
    # point tensor without a cast.
    objective = tf.reduce_sum(log_s) * tf.cast(height * width, log_s.dtype)
    if not reverse:
      w = tf.reshape(w, [1, 1] + w_shape)
      x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC")
    else:
      # TODO(b/111271662): Remove when supported.
      def tpu_inv(m):
        """tf.linalg.inv workaround until it is supported on TPU."""
        q, r = tf.linalg.qr(m)
        return tf.linalg.triangular_solve(r, tf.transpose(q), lower=False)
      w_inv = tf.reshape(tpu_inv(w), [1, 1]+w_shape)
      x = tf.nn.conv2d(
          x, w_inv, [1, 1, 1, 1], "SAME", data_format="NHWC")
      objective *= -1
  return x, objective
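
Illustrative note: a quick NumPy check (not part of the model code) of why the objective above is sum(log_s) scaled by height * width. With W = P L (U + diag(sign(s) exp(log s))), det(P) = +/-1 and det(L) = 1, so log|det W| = sum(log_s); the 1x1 convolution then applies W independently at each of the height * width spatial positions.

import numpy as np
import scipy.linalg

channels = 4
# Any invertible matrix works for the check; the function itself starts from a
# random rotation.
np_w = np.random.rand(channels, channels) + channels * np.eye(channels)
np_p, np_l, np_u = scipy.linalg.lu(np_w)
np_s = np.diag(np_u)
np_sign_s = np.sign(np_s)
np_log_s = np.log(np.abs(np_s))
np_u_strict = np.triu(np_u, k=1)

# Reconstruct W from the PLU factors exactly as the variables are combined above.
w = np_p @ np_l @ (np_u_strict + np.diag(np_sign_s * np.exp(np_log_s)))
print(np.log(np.abs(np.linalg.det(w))), np_log_s.sum())  # The two values match.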
Example No. 54
 def while_exit_cond(logits_so_far, unused_current_hidden):
     length = common_layers.shape_list(logits_so_far)[1]
     return length < max_decode_length
Example No. 55
    def lstm_decoder_infer(self, inputs, sequence_length, hparams, clss, train,
                           initial_state=None, bottleneck=None):
        # In PREDICT mode, run the RNN step by step inside a tf.while_loop.
        max_decode_length = 51
        batch_size = common_layers.shape_list(inputs)[0]
        zero_pad, logits_so_far = self.create_initial_input_for_decode(
            batch_size)

        layers = contrib_rnn.MultiRNNCell([
            self.lstm_cell(hparams, train) for _ in range(hparams.num_hidden_layers)
        ])

        if initial_state is None:
            raise ValueError('initial_state must be initialized from the bottleneck.')

        # append one-hot class to bottleneck, which will be given per step
        clss = tf.reshape(clss, [-1])
        if not hparams.use_cls:
            clss = tf.zeros_like(clss)
        if hparams.condition_on_sln:
            sln = tf.reshape(sequence_length, [-1])
            bottleneck = tf.concat((bottleneck,
                                    tf.one_hot(clss, hparams.num_categories),
                                    tf.one_hot(sln, max_decode_length)), -1)
        else:
            bottleneck = tf.concat((bottleneck,
                                    tf.one_hot(clss, hparams.num_categories)), -1)

        def infer_step(logits_so_far, current_hidden):
            """Inference step of LSTM while loop."""
            # unflatten hidden:
            current_hidden = tuple(tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1])
                                   for s in current_hidden)

            # put logits_so_far through top
            tm = self._problem_hparams.modality['targets']
            # need to reuse top params
            reset_scope = tf.variable_scope(tf.VariableScope(tf.AUTO_REUSE, ''),
                                            reuse=tf.AUTO_REUSE,
                                            auxiliary_name_scope=False)
            top_scope = tf.variable_scope('svg_decoder/{}_modality'.format(tm),
                                          reuse=tf.AUTO_REUSE)
            with reset_scope, top_scope:
                samples_so_far = self.hparams.top['targets'](
                    logits_so_far, None, self.hparams, self.problem_hparams.vocab_size)
            # Append a zero pad to the samples. This effectively shifts the
            # samples right but, unlike shift_right, keeps the last element, so
            # an initially empty samples_so_far is non-empty after padding.
            samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1)
            shifted_targets = common_layers.flatten4d3d(samples_so_far)
            # Only the last position is fed to the RNN at this step.
            shifted_targets = shifted_targets[:, -1:, :]

            # tile and append the bottleneck to inputs
            sln_offset = 0
            if hparams.condition_on_sln:
                sln_offset = 51
            pre_tile_y = tf.reshape(
                bottleneck,
                [common_layers.shape_list(bottleneck)[0], 1,
                 hparams.bottleneck_bits + hparams.num_categories + sln_offset])
            overlay_x = tf.tile(pre_tile_y,
                                [1, common_layers.shape_list(shifted_targets)[1], 1])
            inputs = tf.concat([shifted_targets, overlay_x], -1)

            seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]])

            # RUN PRE-LSTM LAYER
            with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE):
                inputs = tf.layers.dense(
                    inputs, hparams.hidden_size, name='bottom')
                inputs = tf.nn.tanh(inputs)

            # RUN LSTM
            with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE):
                next_step, next_state = tf.nn.dynamic_rnn(
                    layers, inputs, seq_len_batch, initial_state=current_hidden,
                    dtype=tf.float32, time_major=False)

            next_step = tf.expand_dims(next_step, [1])

            logits_so_far = tf.concat([logits_so_far, next_step], 1)
            # flatten state
            next_state = tuple((s.c, s.h) for s in next_state)

            return logits_so_far, next_state

        def while_exit_cond(logits_so_far, unused_current_hidden):
            length = common_layers.shape_list(logits_so_far)[1]
            return length < max_decode_length

        # passing state must be flattened:
        initial_state = tuple([(s.c, s.h) for s in initial_state])

        # actually run tf.while:
        logits, final_state = tf.while_loop(
            while_exit_cond, infer_step,
            [logits_so_far, initial_state],
            shape_invariants=[
                tf.TensorShape([None, None, 1, hparams.hidden_size]),
                tuple([(s[0].get_shape(), s[1].get_shape())
                       for s in initial_state]),
            ],
            back_prop=False,
            parallel_iterations=1
        )

        # logits should be returned in 3d mode:
        logits = common_layers.flatten4d3d(logits)

        return logits, final_state
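
Illustrative note: the decoding routine above follows the standard tf.while_loop pattern of growing a [batch, time, 1, hidden] tensor one step at a time, with a shape invariant that leaves the time dimension unknown. A minimal skeleton of that pattern (assuming TF 1.x; decode_loop and its dummy step are illustrative stand-ins for infer_step, not the model's code):

import tensorflow as tf

def decode_loop(batch_size, hidden_size, max_decode_length):
  """Grows a logits tensor one time step per loop iteration."""
  logits_so_far = tf.zeros([batch_size, 0, 1, hidden_size])
  i0 = tf.constant(0)

  def cond(i, logits_so_far):
    return tf.shape(logits_so_far)[1] < max_decode_length

  def body(i, logits_so_far):
    # Stand-in for one real decode step producing one new time step.
    next_step = tf.zeros([batch_size, 1, 1, hidden_size])
    return i + 1, tf.concat([logits_so_far, next_step], axis=1)

  _, logits = tf.while_loop(
      cond, body, [i0, logits_so_far],
      shape_invariants=[i0.get_shape(),
                        tf.TensorShape([None, None, 1, hidden_size])],
      back_prop=False, parallel_iterations=1)
  return logits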
Example No. 56
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset.  Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    encoder_self_attention_bias = common_attention.attention_bias_same_segment(
        inputs_segmentation, inputs_segmentation)
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(targets_segmentation,
                                                     inputs_segmentation))
  else:
    # Usual case - not a packed dataset.
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  if hparams.get("use_target_space_embedding", True):
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(
        target_space,
        32,
        ishape_static[-1],
        name="target_space_embedding",
        dtype=tf.bfloat16
        if hparams.activation_dtype == "bfloat16" else tf.float32)
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)
  if hparams.activation_dtype == "bfloat16":
    encoder_self_attention_bias = tf.cast(encoder_self_attention_bias,
                                          tf.bfloat16)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             tf.bfloat16)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
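
Illustrative note: in the non-packed branch above, padding positions are detected from all-zero embeddings and turned into a large negative attention bias. A rough sketch of that idea (assuming TF 1.x; padding_bias_from_embeddings is a hypothetical helper, not the common_attention implementation):

import tensorflow as tf

def padding_bias_from_embeddings(emb):
  """Returns a bias that pushes attention away from padded positions."""
  # emb: [batch, length, hidden]; padding positions have all-zero embeddings.
  padding = tf.to_float(tf.equal(tf.reduce_sum(tf.abs(emb), axis=-1), 0.0))
  bias = padding * -1e9
  # Shape [batch, 1, 1, length], broadcastable against attention logits.
  return tf.expand_dims(tf.expand_dims(bias, axis=1), axis=1)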
Example No. 57
    def body(self, features):
        hparams = self.hparams
        batch_size = common_layers.shape_list(features["inputs"])[0]

        # Swap time and batch axes.
        input_frames = common_video.swap_time_and_batch_axes(
            features["inputs"])
        target_frames = common_video.swap_time_and_batch_axes(
            features["targets"])

        # Get actions if they exist, otherwise use zeros.
        input_actions = self.get_input_if_exists(
            features, "input_action", batch_size,
            hparams.video_num_input_frames)
        target_actions = self.get_input_if_exists(
            features, "target_action", batch_size,
            hparams.video_num_target_frames)

        # Get rewards if they exist, otherwise use zeros.
        input_rewards = self.get_input_if_exists(
            features, "input_reward", batch_size,
            hparams.video_num_input_frames)
        target_rewards = self.get_input_if_exists(
            features, "target_reward", batch_size,
            hparams.video_num_target_frames)

        all_actions = tf.concat([input_actions, target_actions], axis=0)
        all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
        all_frames = tf.concat([input_frames, target_frames], axis=0)

        # Each image is used twice: in the latent tower and the main tower.
        # This makes sure we use the *same* image for both, given how TF
        # queues work. It may not be strictly required, but it doesn't hurt.
        all_frames = tf.identity(all_frames)

        gen_images, gen_rewards, latent_means, latent_stds = self.construct_model(
            images=all_frames,
            actions=all_actions,
            rewards=all_rewards,
        )

        extra_loss = self.get_extra_loss(latent_means=latent_means,
                                         latent_stds=latent_stds,
                                         true_frames=all_frames,
                                         gen_frames=gen_images)

        # Visualize predictions in TensorBoard.
        if self.is_training and not self.is_per_pixel_softmax:
            self.visualize_predictions(all_frames[1:], gen_images)

        # Ignore the predictions from the input frames.
        # This is NOT the same as original paper/implementation.
        predictions = gen_images[hparams.video_num_input_frames - 1:]
        reward_pred = gen_rewards[hparams.video_num_input_frames - 1:]
        reward_pred = tf.squeeze(reward_pred,
                                 axis=2)  # Remove extra dimension.

        # Swap back time and batch axes.
        predictions = common_video.swap_time_and_batch_axes(predictions)
        reward_pred = common_video.swap_time_and_batch_axes(reward_pred)

        if hparams.internal_loss:
            # Add the MSE loss for the input frames as well.
            # We assume the modality is L2; otherwise the loss would be
            # inconsistent across frames.
            if self._target_modality != "VideoModalityL2Raw":
                raise ValueError("internal loss only works with L2.")
            recon_loss = tf.losses.mean_squared_error(
                all_frames[1:hparams.video_num_input_frames + 1],
                gen_images[:hparams.video_num_input_frames])
            tf.summary.scalar("mse_extra", recon_loss)
            extra_loss += recon_loss

        return_targets = predictions
        if hparams.reward_prediction:
            return_targets = {
                "targets": predictions,
                "target_reward": reward_pred
            }

        return return_targets, extra_loss
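
Illustrative note: common_video.swap_time_and_batch_axes above is assumed to simply transpose the first two dimensions, turning [batch, time, height, width, channels] into time-major frames and back. A minimal sketch under that assumption (swap_time_and_batch_axes_sketch is illustrative, not the library function):

import tensorflow as tf

def swap_time_and_batch_axes_sketch(x):
  """Transposes the leading batch and time axes of a frame tensor."""
  perm = [1, 0] + list(range(2, x.shape.ndims))
  return tf.transpose(x, perm)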
Example No. 58
def transformer_ffn_layer(x,
                          hparams,
                          pad_remover=None,
                          conv_padding="LEFT",
                          nonpadding_mask=None,
                          losses=None,
                          cache=None,
                          decode_loop_step=None,
                          readout_filter_size=0):
  """Feed-forward layer in the transformer.

  Args:
    x: a Tensor of shape [batch_size, length, hparams.hidden_size]
    hparams: hyperparameters for model
    pad_remover: an expert_utils.PadRemover object tracking the padding
      positions. If provided, when using convolutional settings, the padding
      is removed before applying the convolution, and restored afterward. This
      can give a significant speedup.
    conv_padding: a string - either "LEFT" or "SAME".
    nonpadding_mask: an optional Tensor with shape [batch_size, length].
      needed for convolutional layers with "SAME" padding.
      Contains 1.0 in positions corresponding to nonpadding.
    losses: optional list onto which to append extra training losses
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
        Only used for inference on TPU.
    readout_filter_size: if it's greater than 0, then it will be used instead of
      filter_size.

  Returns:
    a Tensor of shape [batch_size, length, hparams.hidden_size]

  Raises:
    ValueError: If losses arg is None, but layer generates extra losses.
  """
  ffn_layer = hparams.ffn_layer
  relu_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "relu_dropout_broadcast_dims", "")))
  if ffn_layer == "conv_hidden_relu":
    # Backwards compatibility
    ffn_layer = "dense_relu_dense"
  if ffn_layer == "dense_relu_dense":
    # In simple convolution mode, use `pad_remover` to speed up processing.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE,
        value={
            "filter_size": hparams.filter_size,
            "use_bias": "True",
            "activation": mlperf_log.RELU
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE,
        value={
            "hidden_size": hparams.hidden_size,
            "use_bias": "True",
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout)
    if pad_remover:
      original_shape = common_layers.shape_list(x)
      # Collapse `x` across examples, and remove padding positions.
      x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0))
      x = tf.expand_dims(pad_remover.remove(x), axis=0)
    conv_output = quaternion_dense_relu_dense(
        x,
        hparams.filter_size,
        hparams.hidden_size,
        dropout=hparams.relu_dropout,
        dropout_broadcast_dims=relu_dropout_broadcast_dims)
    if pad_remover:
      # Restore `conv_output` to the original shape of `x`, including padding.
      conv_output = tf.reshape(
          pad_remover.restore(tf.squeeze(conv_output, axis=0)), original_shape)
    return conv_output
  elif ffn_layer == "raw_dense_relu_dense":
    # In simple convolution mode, use `pad_remover` to speed up processing.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE,
        value={
            "filter_size": hparams.filter_size,
            "use_bias": "True",
            "activation": mlperf_log.RELU
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE,
        value={
            "hidden_size": hparams.hidden_size,
            "use_bias": "True",
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout)
    if pad_remover:
      original_shape = common_layers.shape_list(x)
      # Collapse `x` across examples, and remove padding positions.
      x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0))
      x = tf.expand_dims(pad_remover.remove(x), axis=0)
    conv_output = common_layers.dense_relu_dense(
        x,
        hparams.filter_size,
        hparams.hidden_size,
        dropout=hparams.relu_dropout,
        dropout_broadcast_dims=relu_dropout_broadcast_dims)
    if pad_remover:
      # Restore `conv_output` to the original shape of `x`, including padding.
      conv_output = tf.reshape(
          pad_remover.restore(tf.squeeze(conv_output, axis=0)), original_shape)
    return conv_output
  elif ffn_layer == "conv_relu_conv":
    return common_layers.conv_relu_conv(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        first_kernel_size=hparams.conv_first_kernel,
        second_kernel_size=1,
        padding=conv_padding,
        nonpadding_mask=nonpadding_mask,
        dropout=hparams.relu_dropout,
        cache=cache,
        decode_loop_step=decode_loop_step)
  elif ffn_layer == "parameter_attention":
    return common_attention.parameter_attention(
        x, hparams.parameter_attention_key_channels or hparams.hidden_size,
        hparams.parameter_attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, readout_filter_size or hparams.filter_size,
        hparams.num_heads,
        hparams.attention_dropout)
  elif ffn_layer == "conv_hidden_relu_with_sepconv":
    return common_layers.conv_hidden_relu(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        kernel_size=(3, 1),
        second_kernel_size=(31, 1),
        padding="LEFT",
        dropout=hparams.relu_dropout)
  elif ffn_layer == "sru":
    return common_layers.sru(x)
  elif ffn_layer == "local_moe_tpu":
    overhead = (
        hparams.moe_overhead_train
        if hparams.mode == tf.estimator.ModeKeys.TRAIN else
        hparams.moe_overhead_eval)
    ret, loss = expert_utils.local_moe_tpu(
        x,
        hparams.filter_size // 2,
        hparams.hidden_size,
        hparams.moe_num_experts,
        overhead=overhead,
        loss_coef=hparams.moe_loss_coef)
    losses.append(loss)
    return ret
  elif ffn_layer == "local_moe":
    overhead = (
        hparams.moe_overhead_train
        if hparams.mode == tf.estimator.ModeKeys.TRAIN else
        hparams.moe_overhead_eval)
    ret, loss = expert_utils.local_moe(
        x,
        True,
        expert_utils.ffn_expert_fn(hparams.hidden_size, [hparams.filter_size],
                                   hparams.hidden_size),
        hparams.moe_num_experts,
        k=hparams.moe_k,
        hparams=hparams)
    losses.append(loss)
    return ret
  else:
    assert ffn_layer == "none"
    return x
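
Illustrative note: the "dense_relu_dense" branches above compute a position-wise feed-forward block (dense projection to filter_size, ReLU, dropout, dense projection back to hidden_size), with the quaternion variant swapping in quaternion_dense_relu_dense. A minimal plain-TF sketch of that block (assuming TF 1.x; dense_relu_dense_sketch is illustrative, not common_layers.dense_relu_dense itself):

import tensorflow as tf

def dense_relu_dense_sketch(x, filter_size, hidden_size, dropout=0.0):
  """Position-wise feed-forward: dense -> relu -> dropout -> dense."""
  h = tf.layers.dense(x, filter_size, activation=tf.nn.relu, name="ffn_conv1")
  if dropout > 0.0:
    h = tf.nn.dropout(h, keep_prob=1.0 - dropout)
  return tf.layers.dense(h, hidden_size, name="ffn_conv2")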
Example No. 59
 def variance_loss(self, b):
   part = tf.random_uniform(common_layers.shape_list(b))
   selection = tf.to_float(tf.less(part, tf.random_uniform([])))
   selection_size = tf.reduce_sum(selection)
   part_avg = tf.abs(tf.reduce_sum(b * selection)) / (selection_size + 1)
   return part_avg
Example No. 60
    def body(self, features):
        hparams = self.hparams
        is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT

        # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
        # using the default (a bit strange) video modality - we should change that.

        # Split inputs and targets into lists.
        input_frames = tf.unstack(features["inputs"], axis=1)
        target_frames = tf.unstack(features["targets"], axis=1)
        all_frames = input_frames + target_frames
        if "input_action" in features:
            input_actions = list(
                tf.split(features["input_action"],
                         hparams.video_num_input_frames,
                         axis=1))
            target_actions = list(
                tf.split(features["target_action"],
                         hparams.video_num_target_frames,
                         axis=1))
            all_actions = input_actions + target_actions

        orig_frame_shape = common_layers.shape_list(all_frames[0])

        # Run a number of steps.
        res_frames, sampled_frames, sampled_frames_raw = [], [], []
        if "target_reward" in features:
            res_rewards, extra_loss = [], 0.0
        sample_prob = common_layers.inverse_exp_decay(
            hparams.scheduled_sampling_warmup_steps)
        sample_prob *= hparams.scheduled_sampling_prob
        for i in range(hparams.video_num_target_frames):
            cur_frames = all_frames[i:i + hparams.video_num_input_frames]
            features["inputs"] = tf.concat(cur_frames, axis=-1)
            features["cur_target_frame"] = all_frames[
                i + hparams.video_num_input_frames]
            if "input_action" in features:
                cur_actions = all_actions[i:i + hparams.video_num_input_frames]
                features["input_action"] = tf.concat(cur_actions, axis=1)

            # Run model.
            with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
                if "target_reward" not in features:
                    res_frame = self.body_single(features)
                else:
                    res_dict, res_extra_loss = self.body_single(features)
                    extra_loss += res_extra_loss
                    res_frame = res_dict["targets"]
                    res_reward = res_dict["target_reward"]
                    res_rewards.append(res_reward)
            res_frames.append(res_frame)

            # Only for Softmax loss: sample frame so we can keep iterating.
            sampled_frame_raw = self.get_sampled_frame(res_frame)
            sampled_frames_raw.append(sampled_frame_raw)
            # TODO(lukaszkaiser): this should be consistent with modality.bottom()
            sampled_frame = common_layers.standardize_images(sampled_frame_raw)
            sampled_frames.append(sampled_frame)

            if is_predicting:
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

            # Scheduled sampling during training.
            if (hparams.scheduled_sampling_prob > 0.0 and self.is_training):
                do_sample = tf.less(tf.random_uniform([orig_frame_shape[0]]),
                                    sample_prob)
                orig_frame = all_frames[i + hparams.video_num_input_frames]
                sampled_frame = tf.where(do_sample, sampled_frame, orig_frame)
                all_frames[i + hparams.video_num_input_frames] = sampled_frame

        # Stack the per-step results and return them.
        frames = tf.stack(res_frames, axis=1)

        if "target_reward" not in features:
            return frames
        rewards = tf.concat(res_rewards, axis=1)
        return {"targets": frames, "target_reward": rewards}, extra_loss