# Example no. 1
# 0
def prepare_decoder(targets, hparams):
    """Prepare decoder for images.

    Pads and reshapes the target sequence so it can be treated as a 2D
    image of shape [batch, rows, img_len, channels], embeds it via
    ``prepare_image``, right-shifts it for autoregressive decoding, and
    adds positional signals.

    NOTE(review): this file defines ``prepare_decoder`` a second time
    further down; in Python the later definition shadows this one, so
    this version is effectively dead code as written.

    Args:
        targets: target tensor. During training assumed to be
            [batch, IMG_LEN, IMG_LEN, 3]; at inference
            [batch, curr_infer_length, 1, 1] (see comment below).
        hparams: hyperparameters object; this function reads
            ``num_channels``, ``mode``, ``block_rastor_scan``,
            ``img_len``, ``query_shape``, ``dec_attention_type`` and
            ``hidden_size``.

    Returns:
        Tuple ``(x, rows, cols)`` where ``x`` is the shifted, embedded
        decoder input of shape [batch, rows, cols, hidden_size] and
        ``rows``/``cols`` are its spatial dimensions after embedding.
    """
    targets_shape = common_layers.shape_list(targets)
    channels = hparams.num_channels
    curr_infer_length = None

    # during training, images are [batch, IMG_LEN, IMG_LEN, 3].
    # At inference, they are [batch, curr_infer_length, 1, 1]
    if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
        curr_infer_length = targets_shape[1]
        # NOTE(review): "rastor" is presumably a misspelling of "raster",
        # but it must match the hparam name as registered elsewhere —
        # the later duplicate of this function spells it
        # ``block_raster_scan``; confirm which name the hparams define.
        if hparams.block_rastor_scan:
            assert hparams.img_len * channels % hparams.query_shape[1] == 0
            assert hparams.img_len % hparams.query_shape[0] == 0
            total_block_width = hparams.img_len * channels
            # Decoding is in block rastor scan order. We divide the image into
            # hparams.query_shape blocks and then decode each block in rastor scan.
            # To make that compatible with our inference pipeline, pad the target so
            # that rows is a multiple of query_shape and columns is a multiple of
            # hparams.img_len*channels
            curr_infer_length = targets_shape[1]
            block_padding_factor = total_block_width * hparams.query_shape[0]
            # -x % m pads up to the next multiple of m (0 when already aligned).
            targets = tf.pad(
                targets,
                [[0, 0], [0, -curr_infer_length % block_padding_factor],
                 [0, 0], [0, 0]])

            num_blocks = total_block_width // hparams.query_shape[1]
            # Reshape the image to represent blocks
            target_blocks = tf.reshape(targets, [
                targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                hparams.query_shape[1]
            ])
            # Transpose to read the image in 2D fashion.
            targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
        else:
            # add padding to make sure the size of targets is a multiple of img_height
            # times number of channels. This is  needed for positional encodings and
            # for doing the RGB lookup.
            padding_factor = channels * hparams.img_len
            targets = tf.pad(targets,
                             [[0, 0], [0, -curr_infer_length % padding_factor],
                              [0, 0], [0, 0]])
        # Collapse back to [batch, rows, img_len, channels].
        targets = tf.reshape(targets,
                             [targets_shape[0], -1, hparams.img_len, channels])
    # Preprocess image
    x = prepare_image(targets, hparams, name="dec_channels")
    x_shape = common_layers.shape_list(x)
    if hparams.dec_attention_type == AttentionType.LOCAL_2D:
        # Shift within query-shaped blocks so position (i) never attends
        # to itself during training.
        x = common_attention.right_shift_blockwise(x, hparams.query_shape)
        x = add_pos_signals(x, hparams, "dec_pos")
    else:
        # Add position signals
        # Flatten to 3D for the standard right shift, then restore 4D.
        x = tf.reshape(
            x,
            [targets_shape[0], x_shape[1] * x_shape[2], hparams.hidden_size])
        x = common_layers.shift_right_3d(x)
        x = tf.reshape(
            x, [targets_shape[0], x_shape[1], x_shape[2], hparams.hidden_size])
        x = add_pos_signals(x, hparams, "dec_pos")
    return x, x_shape[1], x_shape[2]
def prepare_decoder(targets, hparams):
  """Prepare decoder for images.

  Pads and reshapes targets into [batch, rows, img_len, channels], embeds
  them with ``prepare_image``, applies an autoregressive right shift, adds
  positional signals, and casts the result to the targets' dtype.

  Args:
    targets: target tensor. During training [batch, IMG_LEN, IMG_LEN, 3];
      at inference [batch, infer_length, 1, 1].
    hparams: hyperparameters; reads ``num_channels``, ``mode``,
      ``block_raster_scan``, ``img_len``, ``query_shape``,
      ``dec_attention_type`` and ``hidden_size``.

  Returns:
    A tuple ``(x, rows, cols)``: the prepared decoder input
    [batch, rows, cols, hidden_size] plus its spatial dimensions.
  """
  tgt_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  infer_length = None

  # Training images arrive as [batch, IMG_LEN, IMG_LEN, 3]; at inference the
  # partially decoded sequence arrives as [batch, infer_length, 1, 1].
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    infer_length = tgt_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len * channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      row_width = hparams.img_len * channels
      # Decoding proceeds block-by-block in raster scan order: the image is
      # split into hparams.query_shape blocks, each decoded in raster scan.
      # To stay compatible with the inference pipeline, pad so rows are a
      # multiple of query_shape[0] and columns a multiple of
      # img_len * channels. (-x % m pads to the next multiple of m.)
      pad_multiple = row_width * hparams.query_shape[0]
      pad_amount = -infer_length % pad_multiple
      targets = tf.pad(
          targets, [[0, 0], [0, pad_amount], [0, 0], [0, 0]])

      blocks_per_row = row_width // hparams.query_shape[1]
      # View the flat sequence as query-shaped blocks, then transpose so the
      # tensor reads in 2D (row-major) fashion.
      blocked = tf.reshape(
          targets,
          [tgt_shape[0], -1, blocks_per_row, hparams.query_shape[0],
           hparams.query_shape[1]])
      targets = tf.transpose(blocked, [0, 1, 3, 2, 4])
    else:
      # Pad so the target length is a multiple of img_len * channels; needed
      # for positional encodings and the RGB lookup.
      pad_multiple = channels * hparams.img_len
      pad_amount = -infer_length % pad_multiple
      targets = tf.pad(
          targets, [[0, 0], [0, pad_amount], [0, 0], [0, 0]])
    # Collapse back into [batch, rows, img_len, channels].
    targets = tf.reshape(
        targets, [tgt_shape[0], -1, hparams.img_len, channels])
  # Embed pixel values into the model dimension.
  x = prepare_image(targets, hparams, name="dec_channels")
  out_shape = common_layers.shape_list(x)
  local_types = (AttentionType.LOCAL_2D, AttentionType.LOCAL_BLOCK)
  if hparams.dec_attention_type in local_types:
    # Block-wise right shift keeps each position from seeing itself.
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Flatten to 3D for the standard right shift, then restore 4D and add
    # position signals.
    flat = tf.reshape(
        x, [tgt_shape[0], out_shape[1] * out_shape[2], hparams.hidden_size])
    flat = common_layers.shift_right_3d(flat)
    x = tf.reshape(
        flat,
        [tgt_shape[0], out_shape[1], out_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  # Match the dtype of the (possibly lower-precision) targets.
  x = common_layers.cast_like(x, targets)
  return x, out_shape[1], out_shape[2]