def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len * channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len * channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len*channels.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks.
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of img_len
      # times the number of channels. This is needed for positional encodings
      # and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])

    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image.
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals.
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1] * x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]

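# Shape sketch for prepare_decoder (illustrative assumption, not from the
# source: training mode, square RGB targets of shape
# [batch, img_len, img_len, 3] with an identity target modality):
#   x, rows, cols = prepare_decoder(targets, hparams)
#   # x:    [batch, img_len, img_len * 3, hparams.hidden_size], shifted right
#   #       by one position so each output only sees previously decoded values
#   # rows: img_len
#   # cols: img_len * 3 (each RGB channel gets its own column position)
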
def postprocess_image(x, rows, cols, hparams):
  """Postprocessing after decoding.

  Args:
    x: Tensor of shape [batch, ...], where ... can be any rank such that the
      number of elements in x is batch * rows * cols * hparams.hidden_size.
    rows: Integer representing number of rows in a 2-D data point.
    cols: Integer representing number of columns in a 2-D data point.
    hparams: HParams set.

  Returns:
    Tensor of shape [batch, rows, cols, depth], where depth is
    hparams.num_mixtures * 10 if hparams.likelihood is DMOL, otherwise 256. In
    the special case of inference and block raster scan order, it is a Tensor
    of shape [batch, num_blocks_rows, num_block_cols, block_length,
    block_width, depth].
  """
  batch = common_layers.shape_list(x)[0]
  x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size])
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  if likelihood == DistributionType.DMOL:
    depth = hparams.num_mixtures * 10
    targets = tf.layers.dense(x,
                              depth,
                              use_bias=False,
                              activation=None,
                              name="output_conv")
  else:
    depth = 256
    targets = tf.layers.dense(x,
                              depth,
                              use_bias=True,
                              activation=None,
                              name="output_conv")
  if (hparams.mode == tf.estimator.ModeKeys.PREDICT and
      hparams.block_raster_scan):
    y = targets
    yshape = common_layers.shape_list(y)
    block_length = hparams.query_shape[0]
    block_width = hparams.query_shape[1]

    # Break into block row wise.
    y = tf.reshape(y,
                   [batch, yshape[1] // block_length, block_length,
                    yshape[2], depth])
    yshape = common_layers.shape_list(y)
    # Break into blocks width wise.
    y_blocks = tf.reshape(y,
                          [batch, yshape[1], yshape[2],
                           yshape[3] // block_width, block_width, depth])

    # Reshape targets as [batch, num_blocks_rows, num_block_cols, block_length,
    # block_width, depth].
    targets = tf.transpose(y_blocks, [0, 1, 3, 2, 4, 5])

  return targets

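# Shape sketch for postprocess_image (assuming the default categorical
# likelihood; the shapes follow directly from the docstring above):
#   decoder_output = ...  # any shape with batch * rows * cols * hidden_size
#                         # elements, e.g. [batch, rows * cols, hidden_size]
#   logits = postprocess_image(decoder_output, rows, cols, hparams)
#   # logits: [batch, rows, cols, 256], one 256-way distribution per
#   #         channel intensity
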
def prepare_image(inputs, hparams, name=None):
  """Prepare image."""
  inputs_shape = common_layers.shape_list(inputs)
  batch = inputs_shape[0]
  orig_rows = inputs_shape[1]
  orig_cols = inputs_shape[2]
  channels = hparams.num_channels

  hidden_size = hparams.hidden_size
  # TODO(trandustin): Check via modalities.ModalityType.IDENTITY and not str.
  # The current implementation is to avoid circular imports, modalities ->
  # discretization -> common_image_attention -> modalities.
  if "targets" in hparams.modality:
    target_modality_name = hparams.modality["targets"]
    if not isinstance(target_modality_name, str):
      target_modality_name = target_modality_name.__name__
  else:
    target_modality_name = None
  if target_modality_name == "IdentityModality":
    inputs = tf.to_int32(inputs)
    x = get_channel_embeddings(channels, inputs, hidden_size, name=name)
  else:
    x = inputs

  x = tf.reshape(x, [batch, orig_rows, orig_cols * channels, hidden_size])
  return x

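# Note on prepare_image: when the target modality is the identity (no
# embedding is applied upstream), the raw integer pixels are embedded here via
# get_channel_embeddings; otherwise the inputs are assumed to already carry a
# hidden_size dimension. Either way, the result is reshaped so that every
# channel of every pixel occupies its own column:
#   [batch, rows, cols, channels] -> [batch, rows, cols * channels, hidden_size]
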
def prepare_encoder(inputs, hparams, attention_type="local_1d"):
  """Prepare encoder for images."""
  x = prepare_image(inputs, hparams, name="enc_channels")
  # Add position signals.
  x = add_pos_signals(x, hparams, "enc_pos")
  x_shape = common_layers.shape_list(x)
  if attention_type == "local_1d":
    x = tf.reshape(x, [x_shape[0], x_shape[1] * x_shape[2],
                       hparams.hidden_size])
    x.set_shape([None, None, hparams.hidden_size])
  elif attention_type == "local_2d":
    x.set_shape([None, None, None, hparams.hidden_size])
  return x

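# Shape sketch for prepare_encoder with the default local_1d attention
# (illustrative assumption: square inputs of shape
# [batch, img_len, img_len, channels]):
#   x = prepare_encoder(inputs, hparams, attention_type="local_1d")
#   # x: [batch, img_len * img_len * channels, hidden_size], with positional
#   #    signals added before the flattening reshape
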
def get_self_attention_bias(x):
  """Creates masked self-attention bias.

  Args:
    x: A tensor of shape [batch, length, depth].

  Returns:
    self_attention_bias: A tensor of shape [1, 1, length, length] containing a
      large negative value above the diagonal, so that each position can only
      attend to itself and earlier positions.
  """
  x_shape = common_layers.shape_list(x)
  self_attention_bias = common_attention.attention_bias_lower_triangle(
      x_shape[1])
  return self_attention_bias

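# Usage sketch for get_self_attention_bias (assumes `x` is the flattened
# decoder input of shape [batch, length, hidden_size]):
#   bias = get_self_attention_bias(x)
#   # bias broadcasts against attention logits of shape
#   # [batch, num_heads, length, length]; entries above the diagonal hold a
#   # large negative value, masking attention to future positions.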