Example #1
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len*channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len*channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan
      # order. To make that compatible with our inference pipeline, pad the
      # targets so that the flattened length is a multiple of
      # query_shape[0] * img_len * channels, i.e. one full row of blocks.
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of img_len
      # times the number of channels. This is needed for positional encodings
      # and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Flatten, shift right along the length dimension, and add position
    # signals.
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1]*x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
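
A note on the padding arithmetic above: in Python, -n % m yields the smallest non-negative amount that must be added to n to reach the next multiple of m (0 if n already is one). A minimal standalone sketch of this pad-to-multiple idiom, using a hypothetical padding_factor of 24 (e.g. img_len=8 with 3 channels):

# Pure-Python illustration of the pad-to-multiple idiom used in
# prepare_decoder. The padding_factor value here is hypothetical.
padding_factor = 24  # e.g. img_len=8, channels=3
for curr_infer_length in (0, 5, 24, 25):
    pad = -curr_infer_length % padding_factor
    assert (curr_infer_length + pad) % padding_factor == 0
    print(curr_infer_length, "+", pad, "=", curr_infer_length + pad)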
Example #2
def postprocess_image(x, rows, cols, hparams):
  """Postprocessing after decoding.

  Args:
    x: Tensor of shape [batch, ...], where ... can be any rank such that the
      number of elements in x is batch * rows * cols * hparams.hidden_size.
    rows: Integer representing number of rows in a 2-D data point.
    cols: Integer representing number of columns in a 2-D data point.
    hparams: HParams set.

  Returns:
    Tensor of shape [batch, rows, cols, depth], where depth is
    hparams.num_mixtures * 10 if hparams.likelihood is DMOL, otherwise 256. In
    the special case of inference and block raster scan order, it is a Tensor
    of shape [batch, num_blocks_rows, num_block_cols, block_length, block_width,
    depth].
  """
  batch = common_layers.shape_list(x)[0]
  x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size])
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  if likelihood == DistributionType.DMOL:
    depth = hparams.num_mixtures * 10
    targets = tf.layers.dense(x,
                              depth,
                              use_bias=False,
                              activation=None,
                              name="output_conv")
  else:
    depth = 256
    targets = tf.layers.dense(x,
                              depth,
                              use_bias=True,
                              activation=None,
                              name="output_conv")
  if (hparams.mode == tf.estimator.ModeKeys.PREDICT and
      hparams.block_raster_scan):
    y = targets
    yshape = common_layers.shape_list(y)
    block_length = hparams.query_shape[0]
    block_width = hparams.query_shape[1]

    # Break into blocks row-wise.
    y = tf.reshape(y,
                   [batch, yshape[1] // block_length, block_length,
                    yshape[2], depth])
    yshape = common_layers.shape_list(y)
    # Break into blocks width-wise.
    y_blocks = tf.reshape(y,
                          [batch, yshape[1], yshape[2],
                           yshape[3] // block_width, block_width, depth])

    # Reshape targets as [batch, num_blocks_rows, num_block_cols, block_length,
    # block_width, depth].
    targets = tf.transpose(y_blocks, [0, 1, 3, 2, 4, 5])

  return targets
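
The two reshapes plus transpose above are a standard block decomposition. A minimal numpy sketch of the same permutation, with hypothetical sizes (block_length = block_width = 2 on a 4x4 grid):

import numpy as np

# Numpy sketch of the block split in postprocess_image: two reshapes break
# rows and columns into blocks, and the transpose groups each block's
# elements together. All sizes here are hypothetical.
batch, rows, cols, depth = 1, 4, 4, 3
block_length, block_width = 2, 2
y = np.arange(batch * rows * cols * depth).reshape(batch, rows, cols, depth)
# Break into blocks row-wise.
y = y.reshape(batch, rows // block_length, block_length, cols, depth)
# Break into blocks width-wise.
y = y.reshape(batch, rows // block_length, block_length,
              cols // block_width, block_width, depth)
# Reorder to [batch, num_block_rows, num_block_cols, block_length,
# block_width, depth], matching tf.transpose(..., [0, 1, 3, 2, 4, 5]).
blocks = np.transpose(y, (0, 1, 3, 2, 4, 5))
print(blocks.shape)  # (1, 2, 2, 2, 2, 3)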
Example #3
def prepare_image(inputs, hparams, name=None):
  """Prepare image."""
  inputs_shape = common_layers.shape_list(inputs)
  batch = inputs_shape[0]
  orig_rows = inputs_shape[1]
  orig_cols = inputs_shape[2]
  channels = hparams.num_channels

  hidden_size = hparams.hidden_size
  # TODO(trandustin): Check via modalities.ModalityType.IDENTITY and not str.
  # The current implementation uses string names to avoid circular imports:
  # modalities -> discretization -> common_image_attention -> modalities.
  if "targets" in hparams.modality:
    target_modality_name = hparams.modality["targets"]
    if not isinstance(target_modality_name, str):
      target_modality_name = target_modality_name.__name__
  else:
    target_modality_name = None
  if target_modality_name == "IdentityModality":
    inputs = tf.to_int32(inputs)
    x = get_channel_embeddings(channels, inputs, hidden_size, name=name)
  else:
    x = inputs
  x = tf.reshape(x, [batch, orig_rows, orig_cols * channels, hidden_size])

  return x
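
The final reshape above unfolds the per-channel embeddings so that every colour channel occupies its own column position. A small numpy sketch with hypothetical sizes, assuming the embedded input has shape [batch, rows, cols, channels * hidden_size] (as get_channel_embeddings produces by concatenating per-channel embeddings on the last axis):

import numpy as np

# Numpy sketch of the channel-unfolding reshape at the end of prepare_image.
# Sizes are hypothetical; x stands in for the embedded channels.
batch, rows, cols, channels, hidden_size = 1, 2, 2, 3, 4
x = np.zeros((batch, rows, cols, channels * hidden_size))
x = x.reshape(batch, rows, cols * channels, hidden_size)
print(x.shape)  # (1, 2, 6, 4)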
Example #4
def prepare_encoder(inputs, hparams, attention_type="local_1d"):
  """Prepare encoder for images."""
  x = prepare_image(inputs, hparams, name="enc_channels")
  # Add position signals.
  x = add_pos_signals(x, hparams, "enc_pos")
  x_shape = common_layers.shape_list(x)
  if attention_type == "local_1d":
    x = tf.reshape(x, [x_shape[0], x_shape[1]*x_shape[2], hparams.hidden_size])
    x.set_shape([None, None, hparams.hidden_size])
  elif attention_type == "local_2d":
    x.set_shape([None, None, None, hparams.hidden_size])
  return x
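
For local 1-D attention, the image grid is flattened into a single sequence so a standard sequence-attention kernel can run over it in raster-scan order. A numpy sketch of that flattening, with hypothetical sizes:

import numpy as np

# Numpy sketch of the local_1d flattening in prepare_encoder: a
# [batch, rows, cols, hidden_size] grid becomes a
# [batch, rows * cols, hidden_size] sequence. Sizes are hypothetical.
batch, rows, cols, hidden_size = 1, 2, 3, 4
x = np.arange(batch * rows * cols * hidden_size).reshape(
    batch, rows, cols, hidden_size)
seq = x.reshape(batch, rows * cols, hidden_size)
print(seq.shape)  # (1, 6, 4)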
Example #5
def get_self_attention_bias(x):
  """Creates masked self attention bias.

  Args:
    x: A tensor of shape [batch, length, depth]

  Returns:
    self_attention_bias: A tensor of shape [1, 1, length, length].
  """

  x_shape = common_layers.shape_list(x)
  self_attention_bias = common_attention.attention_bias_lower_triangle(
      x_shape[1])
  return self_attention_bias
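
The bias returned above is the usual causal mask: zero on and below the diagonal, a large negative value above it, so the softmax assigns near-zero weight to future positions. A standalone numpy sketch of such a mask (an illustration of the idea, not the exact tensor2tensor implementation):

import numpy as np

# Standalone sketch of a lower-triangular attention bias. Added to the
# attention logits, the -1e9 entries push future positions to ~0 after
# softmax; attention_bias_lower_triangle returns the analogous tensor
# broadcastable as [1, 1, length, length].
length = 4
band = np.tril(np.ones((length, length)))
bias = -1e9 * (1.0 - band)
print(bias)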