def body(self, features):
  """Runs the decoder-only image model over the targets.

  Args:
    features: dict providing "inputs" (added to the decoder input unless
      hparams.unconditional is set) and "targets".

  Returns:
    The value produced by cia.create_output for the decoder output.
  """
  hp = copy.copy(self._hparams)
  inputs = features["inputs"]
  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)

  # Emit the image summary only when the variable scope is not in reuse mode
  # and the model is not running inference.
  reusing = tf.get_variable_scope().reuse
  if not reusing and hp.mode != tf.contrib.learn.ModeKeys.INFER:
    tf.summary.image("targets", targets, max_outputs=1)

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)

  if not hp.unconditional:
    # Add class label to decoder input.
    label_shape = [targets_shape[0], 1, 1, hp.hidden_size]
    dec_in = dec_in + tf.reshape(inputs, label_shape)

  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, rows, cols, targets, hp)
def body(self, features):
  """Runs the encoder over the inputs and the decoder over the targets.

  Args:
    features: dict providing "inputs" (encoder side) and "targets"
      (decoder side).

  Returns:
    The value produced by cia.create_output for the decoder output.
  """
  hp = copy.copy(self._hparams)
  targets = features["targets"]
  inputs = features["inputs"]

  # Image summaries are skipped when the variable scope is in reuse mode or
  # when predicting.
  reusing = tf.get_variable_scope().reuse
  if not reusing and hp.mode != tf.estimator.ModeKeys.PREDICT:
    for tag, image in (("inputs", inputs), ("targets", targets)):
      tf.summary.image(tag, image, max_outputs=1)

  enc_in = cia.prepare_encoder(inputs, hp)
  enc_out = cia.transformer_encoder_layers(
      enc_in,
      hp.num_encoder_layers,
      hp,
      attention_type=hp.enc_attention_type,
      name="encoder")

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      enc_out,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, rows, cols, targets, hp)
def transformer_image_decoder(encoder_output,
                              ed_attention_bias,
                              targets,
                              hparams,
                              name="transformer_dec"):
  """Transformer decoder over image targets.

  Args:
    encoder_output: forwarded to cia.transformer_decoder_layers as the
      encoder output.
    ed_attention_bias: forwarded as encoder_decoder_attention_bias.
    targets: Tensor with batch as its first dimension; it is reshaped to
      [batch, img_len, img_len, num_channels * hidden_size].
    hparams: hyperparameters; img_len, num_channels, hidden_size,
      num_decoder_layers / num_hidden_layers and dec_attention_type are read.
    name: string, variable scope.

  Returns:
    Tensor reshaped to
    [batch, img_len, img_len * num_channels, hidden_size].
  """
  with tf.variable_scope(name):
    side = hparams.img_len
    channels = hparams.num_channels
    hidden = hparams.hidden_size

    batch = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch, side, side, channels * hidden])

    # prepare_decoder yields the decoder input plus two values that are not
    # needed here.
    dec_in, _, _ = cia.prepare_decoder(targets, hparams)

    # Fall back to num_hidden_layers when num_decoder_layers is falsy.
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    out_batch = common_layers.shape_list(dec_out)[0]
    return tf.reshape(dec_out, [out_batch, side, side * channels, hidden])
def body(self, features):
  """Runs the decoder-only image model, optionally returning extra losses.

  Args:
    features: dict providing "inputs" (added to the decoder input unless
      hparams.unconditional is set) and "targets".

  Returns:
    The cia.create_output result, or a tuple of that result and
    {"extra_loss": <sum of collected losses>} when the decoder layers
    appended anything to the losses list.
  """
  hp = copy.copy(self._hparams)
  inputs = features["inputs"]
  targets = features["targets"]

  reusing = tf.get_variable_scope().reuse
  if not reusing and hp.mode != tf.contrib.learn.ModeKeys.INFER:
    tf.summary.image("targets", tf.to_float(targets), max_outputs=1)

  # Filled in by the decoder layers (e.g. when mixture-of-experts is used).
  extra_losses = []

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)

  if not hp.unconditional:
    # Add class label to decoder input.
    batch = common_layers.shape_list(targets)[0]
    dec_in = dec_in + tf.reshape(inputs, [batch, 1, 1, hp.hidden_size])

  num_layers = hp.num_decoder_layers or hp.num_hidden_layers
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      num_layers,
      hp,
      attention_type=hp.dec_attention_type,
      losses=extra_losses,
      name="decoder")

  output = cia.create_output(dec_out, rows, cols, targets, hp)
  if not extra_losses:
    return output
  return output, {"extra_loss": tf.add_n(extra_losses)}
def body(self, features):
  """Runs the decoder and projects its output to mixture parameters.

  Args:
    features: dict providing "inputs" (added to the decoder input unless
      hparams.unconditional is set) and "targets".

  Returns:
    Output of a bias-free dense layer with hparams.num_mixtures * 10 units
    applied to the decoder output reshaped to the shape of the targets.
  """
  hp = copy.copy(self._hparams)
  inputs = features["inputs"]
  targets = features["targets"]

  dec_in, _, _ = cia.prepare_decoder(targets, hp)

  if not hp.unconditional:
    # Add class label to decoder input.
    batch = common_layers.shape_list(targets)[0]
    dec_in = dec_in + tf.reshape(inputs, [batch, 1, 1, hp.hidden_size])

  num_layers = hp.num_decoder_layers or hp.num_hidden_layers
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      num_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")

  # Bring the decoder output back to the shape of the targets.
  dec_out = tf.reshape(dec_out, tf.shape(targets))

  # Ten parameter sets are produced per mixture (location, scale and
  # coefficient parameters).
  # TODO(avaswani) Figure out if we need residuals or layer norm
  return tf.layers.dense(
      dec_out,
      hp.num_mixtures * 10,
      use_bias=False,
      activation=None,
      name="output_mixtures_conv")
def body(self, features):
  """Runs the decoder-only image model, optionally returning extra losses.

  Args:
    features: dict providing "targets" and, unless hparams.unconditional is
      set, "inputs" (added to the decoder input).

  Returns:
    The cia.create_output result, or a tuple of that result and
    {"extra_loss": <sum of collected losses>} when the decoder layers
    appended anything to the losses list.

  Raises:
    ValueError: if the likelihood is DMOL and hparams.num_channels is not 1.
  """
  hp = copy.copy(self._hparams)
  targets = features["targets"]

  if hp.likelihood == cia.DistributionType.DMOL:
    if hp.num_channels != 1:
      raise ValueError("When using DMOL for the likelihood, bottom function "
                       " must be identity and num_channels must be 1.")

  reusing = tf.get_variable_scope().reuse
  if not (reusing or hp.mode == tf.estimator.ModeKeys.PREDICT):
    tf.summary.image("targets", tf.to_float(targets), max_outputs=1)

  # Filled in by the decoder layers (e.g. when mixture-of-experts is used).
  extra_losses = []

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)

  if not hp.unconditional:
    # Add class label to decoder input; "inputs" is only read on this path.
    label = features["inputs"]
    batch = common_layers.shape_list(targets)[0]
    dec_in = dec_in + tf.reshape(label, [batch, 1, 1, hp.hidden_size])

  num_layers = hp.num_decoder_layers or hp.num_hidden_layers
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      num_layers,
      hp,
      attention_type=hp.dec_attention_type,
      losses=extra_losses,
      name="decoder")

  output = cia.create_output(dec_out, rows, cols, targets, hp)
  if not extra_losses:
    return output
  return output, {"extra_loss": tf.add_n(extra_losses)}
def generator(self, inputs, targets):
  """Encoder-decoder generator (from img2img_transformer_2d).

  Args:
    inputs: fed through cia.prepare_encoder and the encoder layers.
    targets: fed through cia.prepare_decoder and the decoder layers.

  Returns:
    The value produced by cia.create_output for the decoder output.
  """
  hp = copy.copy(self._hparams)

  enc_out = cia.transformer_encoder_layers(
      cia.prepare_encoder(inputs, hp),
      hp.num_encoder_layers,
      hp,
      attention_type=hp.enc_attention_type,
      name="encoder")

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      enc_out,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")

  return cia.create_output(dec_out, rows, cols, targets, hp)
def body(self, features):
  """Runs the decoder-only image model over the targets.

  Args:
    features: dict providing "inputs" (added to the decoder input unless
      hparams.unconditional is set) and "targets".

  Returns:
    The value produced by cia.create_output for the decoder output.
  """
  hp = copy.copy(self._hparams)
  inputs = features["inputs"]
  targets = features["targets"]
  batch = common_layers.shape_list(targets)[0]

  # Skip the image summary when the variable scope is in reuse mode or when
  # predicting.
  reusing = tf.get_variable_scope().reuse
  if not reusing and hp.mode != tf.estimator.ModeKeys.PREDICT:
    tf.summary.image("targets", targets, max_outputs=1)

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)

  if not hp.unconditional:
    # Add class label to decoder input.
    dec_in = dec_in + tf.reshape(inputs, [batch, 1, 1, hp.hidden_size])

  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, rows, cols, targets, hp)
def transformer_latent_decoder(encoder_output,
                               ed_attention_bias,
                               targets,
                               hparams,
                               name="transformer_latent_dec"):
  """Transformer decoder over latents using hparams.latent_attention_type.

  Args:
    encoder_output: forwarded to cia.transformer_decoder_layers as the
      encoder output.
    ed_attention_bias: forwarded as encoder_decoder_attention_bias.
    targets: Tensor with batch as its first dimension. It is reshaped to
      [batch, img_len // r, img_len * num_latents // r, hidden_size] with
      r = 2**(num_compress_steps // 2).
    hparams: hyperparameters; img_len, num_latents, num_compress_steps,
      hidden_size, num_latent_layers / num_hidden_layers and
      latent_attention_type are read.
    name: string, variable scope.

  Returns:
    Tensor reshaped to
    [batch, img_len * img_len * num_latents // 2**num_compress_steps,
     hidden_size].
  """
  with tf.variable_scope(name):
    batch_size = common_layers.shape_list(targets)[0]
    compress_ratio = 2**(hparams.num_compress_steps // 2)

    # Shape entries must be integers. "/" is true division under Python 3
    # (and with `from __future__ import division`) and would put floats into
    # the shape, which tf.reshape rejects; "//" keeps them integral.
    targets = tf.reshape(targets, [
        batch_size,
        hparams.img_len // compress_ratio,
        (hparams.img_len * hparams.num_latents) // compress_ratio,
        hparams.hidden_size
    ])

    # Prepare decoder inputs and bias.
    decoder_input, _, _ = cia.prepare_decoder(targets, hparams)

    # Fall back to num_hidden_layers when num_latent_layers is falsy.
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    # Flatten the two spatial dimensions back into one length dimension.
    decoder_output_shape = common_layers.shape_list(decoder_output)
    latent_length = (
        (hparams.img_len * hparams.img_len * hparams.num_latents) //
        (2**hparams.num_compress_steps))
    decoder_output = tf.reshape(
        decoder_output,
        [decoder_output_shape[0], latent_length, hparams.hidden_size])
    return decoder_output
def body(self, features):
  """Runs the encoder over the inputs and the decoder over the targets.

  Args:
    features: dict providing "inputs" (encoder side) and "targets"
      (decoder side).

  Returns:
    The value produced by cia.create_output for the decoder output.
  """
  hp = copy.copy(self._hparams)
  targets = features["targets"]
  inputs = features["inputs"]

  # Image summaries are skipped when the variable scope is in reuse mode or
  # during inference.
  reusing = tf.get_variable_scope().reuse
  if not reusing and hp.mode != tf.contrib.learn.ModeKeys.INFER:
    for tag, image in (("inputs", inputs), ("targets", targets)):
      tf.summary.image(tag, image, max_outputs=1)

  enc_in = cia.prepare_encoder(inputs, hp)
  enc_out = cia.transformer_encoder_layers(
      enc_in,
      hp.num_encoder_layers,
      hp,
      attention_type=hp.enc_attention_type,
      name="encoder")

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      enc_out,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, rows, cols, targets, hp)
def body(self, features):
  """Encoder-decoder body followed by a block-wise feed-forward head.

  The decoder output is passed through common_layers.dense_relu_dense sized
  by hparams.block_size, reshaped so that block_size becomes its own
  dimension, and combined with the decoder output via
  common_layers.layer_postprocess.

  Args:
    features: dict providing "inputs" (encoder side) and "targets"
      (decoder side).

  Returns:
    Tensor reshaped to [batch, rows, cols, block_size, hidden_size] and then
    passed through common_layers.layer_postprocess.
  """
  own_hp = self._hparams
  block_size = own_hp.block_size
  assert block_size > 0
  assert not common_layers.is_xla_compiled()

  hp = copy.copy(own_hp)
  targets = features["targets"]
  inputs = features["inputs"]

  reusing = tf.get_variable_scope().reuse
  if not reusing and hp.mode != tf.estimator.ModeKeys.PREDICT:
    tf.summary.image("inputs", inputs, max_outputs=1)
    tf.summary.image("targets", targets, max_outputs=1)

  enc_in = cia.prepare_encoder(inputs, hp)
  enc_out = cia.transformer_encoder_layers(
      enc_in,
      hp.num_encoder_layers,
      hp,
      attention_type=hp.enc_attention_type,
      name="encoder")

  # The row/column counts from prepare_decoder are not used; they are
  # recomputed from the decoder output below.
  dec_in, _, _ = cia.prepare_decoder(targets, hp)
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      enc_out,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  assert not isinstance(dec_out, tuple)
  assert len(dec_out.shape) == 4

  # relu_dropout_broadcast_dims is optional; default to the empty string.
  broadcast_spec = getattr(own_hp, "relu_dropout_broadcast_dims", "")
  broadcast_dims = common_layers.comma_separated_string_to_integer_list(
      broadcast_spec)

  with tf.variable_scope("block_size_%d" % block_size):
    tf.logging.info("Using block_size %d", block_size)
    block_out = common_layers.dense_relu_dense(
        dec_out,
        block_size * own_hp.filter_size,
        block_size * own_hp.hidden_size,
        dropout=own_hp.relu_dropout,
        dropout_broadcast_dims=broadcast_dims)

  batch, rows, cols = common_layers.shape_list(dec_out)[:3]
  hidden = own_hp.hidden_size
  # Give the decoder output a singleton block dimension so it lines up with
  # the per-block output inside layer_postprocess.
  dec_out = tf.reshape(dec_out, [batch, rows, cols, 1, hidden])
  block_out = tf.reshape(block_out, [batch, rows, cols, block_size, hidden])

  return common_layers.layer_postprocess(dec_out, block_out, own_hp)
def _build_layers_v2(self, input_dict, num_outputs, options):
  """Builds an encoder-decoder trunk followed by two conv layers.

  Args:
    input_dict: dict providing "obs" (encoder input) and "prev_actions"
      (decoder input).
    num_outputs: number of output channels of the final 1x1 convolution.
    options: dict; options["custom_options"]["hparams"] supplies the
      hyperparameters and options.get("conv_activation") names the
      activation of the first convolution.

  Returns:
    Tuple (flatten(fc2), flatten(fc1)): the flattened final convolution
    output and the flattened output of the first convolution.
  """
  hp = copy.copy(options["custom_options"]["hparams"])
  targets = input_dict["prev_actions"]
  inputs = input_dict["obs"]

  # NOTE: image summaries and cia.create_output are intentionally not used
  # here; the decoder output feeds the convolutional head directly.
  with tf.name_scope('enc_prep'):
    enc_in = cia.prepare_encoder(inputs, hp)

  with tf.name_scope('enc_layers'):
    enc_out = cia.transformer_encoder_layers(
        enc_in,
        hp.num_encoder_layers,
        hp,
        attention_type=hp.enc_attention_type,
        name="encoder")

  with tf.name_scope('dec_prep'):
    dec_in, _, _ = cia.prepare_decoder(targets, hp)

  with tf.name_scope('dec_layers'):
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        enc_out,
        hp.num_decoder_layers,
        hp,
        attention_type=hp.dec_attention_type,
        name="decoder")

  conv_activation = get_activation_fn(options.get("conv_activation"))
  # 32 filters, 3x3 kernel, stride 2, no padding.
  fc1 = slim.conv2d(
      dec_out,
      32,
      [3, 3],
      2,
      activation_fn=conv_activation,
      padding="VALID",
      scope="fc1")
  # Linear 1x1 projection to the requested number of outputs.
  fc2 = slim.conv2d(
      fc1,
      num_outputs,
      [1, 1],
      activation_fn=None,
      normalizer_fn=None,
      scope="fc2")

  outputs = flatten(fc2)
  last_layer = flatten(fc1)
  return outputs, last_layer
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None):
  """Decodes targets with either the text or the image Transformer decoder.

  Args:
    encoder_output: encoder output; only used by the "translate" task.
    encoder_decoder_attention_bias: encoder-decoder attention bias; only used
      by the "translate" task.
    targets: decoder targets. For "translate" they are flattened with
      common_layers.flatten4d3d; for "image" they are reshaped to
      [batch, img_len, img_len, num_channels * hidden_size].
    hparams: hyperparameters.
    name: string, variable scope.
    task: "translate" or "image"; defaults to hparams.task when None.

  Returns:
    Decoder output reshaped to [batch, -1, 1, hparams.hidden_size].

  Raises:
    ValueError: for the "image" task when hparams.drop_inputs is false,
      because this function has no class-label inputs to condition on.
  """
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(
          decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      # No class-label tensor is passed to this function.
      inputs = None
      # The targets have to be reshaped to
      # [b, img_len, img_len, num_channels * hidden_size] because otherwise
      # prepare_decoder will choke.
      targets = tf.reshape(targets, [
          tf.shape(targets)[0], hparams.img_len, hparams.img_len,
          hparams.num_channels * hparams.hidden_size
      ])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      if not hparams.drop_inputs:
        if inputs is None:
          # Previously None was handed straight to tf.reshape, which fails
          # with an opaque conversion error; fail with an actionable message.
          raise ValueError(
              "decode_transformer has no class-label inputs for the image "
              "task; set hparams.drop_inputs to skip label conditioning.")
        decoder_input += tf.reshape(inputs, [
            common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size
        ])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    # Expand since t2t expects 4d tensors.
    decoder_output = tf.reshape(
        decoder_output, [decoder_output_shape[0], -1, 1, hparams.hidden_size])
    return decoder_output
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None):
  """Decodes targets with either the text or the image Transformer decoder.

  Args:
    encoder_output: encoder output; only used by the "translate" task.
    encoder_decoder_attention_bias: encoder-decoder attention bias; only used
      by the "translate" task.
    targets: decoder targets. For "translate" they are flattened with
      common_layers.flatten4d3d; for "image" they are reshaped to
      [batch, img_len, img_len, num_channels * hidden_size].
    hparams: hyperparameters.
    name: string, variable scope.
    task: "translate" or "image"; defaults to hparams.task when None.

  Returns:
    Decoder output reshaped to [batch, -1, 1, hparams.hidden_size].

  Raises:
    ValueError: for the "image" task when hparams.drop_inputs is false,
      because this function has no class-label inputs to condition on.
  """
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      decoder_output = transformer.transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_bias,
          encoder_decoder_attention_bias,
          hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      # No class-label tensor is passed to this function.
      inputs = None
      # The targets have to be reshaped to
      # [b, img_len, img_len, num_channels * hidden_size] because otherwise
      # prepare_decoder will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels*hparams.hidden_size])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      if not hparams.drop_inputs:
        if inputs is None:
          # Previously None was handed straight to tf.reshape, which fails
          # with an opaque conversion error; fail with an actionable message.
          raise ValueError(
              "decode_transformer has no class-label inputs for the image "
              "task; set hparams.drop_inputs to skip label conditioning.")
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    # Expand since t2t expects 4d tensors.
    decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1, 1,
                                                 hparams.hidden_size])
    return decoder_output
def transformer_latent_decoder(x,
                               encoder_output,
                               ed_attention_bias,
                               hparams,
                               name="transformer_latent_dec"):
  """Transformer decoder over latents using latent_attention_type.

  Args:
    x: Tensor with batch as its first dimension. It is reshaped to
      [batch, img_len // r, img_len * num_latents // r, hidden_size] with
      r = 2**(num_compress_steps // 2).
    encoder_output: Tensor, encoder output forwarded to the decoder layers.
    ed_attention_bias: Tensor, forwarded as encoder_decoder_attention_bias.
    hparams: hyperparameters; img_len, num_latents, num_compress_steps,
      hidden_size, num_latent_layers / num_hidden_layers and
      latent_attention_type are read.
    name: string, variable scope.

  Returns:
    Tensor reshaped to
    [batch, img_len * img_len * num_latents // 2**num_compress_steps,
     hidden_size].
  """
  with tf.variable_scope(name):
    batch_size = common_layers.shape_list(x)[0]
    compress_ratio = 2**(hparams.num_compress_steps // 2)

    # Shape entries must be integers. "/" is true division under Python 3
    # (and with `from __future__ import division`) and would put floats into
    # the shape, which tf.reshape rejects; "//" keeps them integral.
    x = tf.reshape(x, [
        batch_size,
        hparams.img_len // compress_ratio,
        (hparams.img_len * hparams.num_latents) // compress_ratio,
        hparams.hidden_size
    ])

    # Prepare decoder inputs and bias.
    decoder_input, _, _ = cia.prepare_decoder(x, hparams)

    # Fall back to num_hidden_layers when num_latent_layers is falsy.
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    # Flatten the two spatial dimensions back into one length dimension.
    decoder_output_shape = common_layers.shape_list(decoder_output)
    latent_length = (
        (hparams.img_len * hparams.img_len * hparams.num_latents) //
        (2**hparams.num_compress_steps))
    decoder_output = tf.reshape(
        decoder_output,
        [decoder_output_shape[0], latent_length, hparams.hidden_size])
    return decoder_output
def transformer_latent_decoder(x,
                               encoder_output,
                               ed_attention_bias,
                               hparams,
                               name="transformer_latent_dec"):
  """Transformer decoder over latents using latent_attention_type.

  Args:
    x: Tensor of shape [batch, ...], and whose size is batch * length_q *
      hparams.hidden_size. Here, length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name):
    batch_size = common_layers.shape_list(x)[0]
    compress_ratio = 2**(hparams.num_compress_steps // 2)

    # Shape entries must be integers. "/" is true division under Python 3
    # (and with `from __future__ import division`) and would put floats into
    # the shape, which tf.reshape rejects; "//" keeps them integral.
    x = tf.reshape(x, [
        batch_size,
        hparams.img_len // compress_ratio,
        (hparams.img_len * hparams.num_latents) // compress_ratio,
        hparams.hidden_size
    ])

    decoder_input, _, _ = cia.prepare_decoder(x, hparams)

    # Fall back to num_hidden_layers when num_latent_layers is falsy.
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    # Flatten the two spatial dimensions back into length_q.
    decoder_output_shape = common_layers.shape_list(decoder_output)
    latent_length = (
        (hparams.img_len * hparams.img_len * hparams.num_latents) //
        (2**hparams.num_compress_steps))
    decoder_output = tf.reshape(
        decoder_output,
        [decoder_output_shape[0], latent_length, hparams.hidden_size])
    return decoder_output
def transformer_latent_decoder(x,
                               encoder_output,
                               ed_attention_bias,
                               hparams,
                               name=None):
  """Decodes the latents with a Transformer using latent_attention_type.

  Args:
    x: Tensor of shape [batch, length_q, hparams.hidden_size]. length_q is
      the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: HParams.
    name: string, variable scope; defaults to "transformer_latent_dec".

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_latent_dec"):
    batch = common_layers.shape_list(x)[0]
    hidden = hparams.hidden_size
    num_latents = hparams.num_latents

    # Side length of the compressed image grid.
    side = hparams.img_len // 2**(hparams.num_compress_steps // 2)

    # Lay the latents out on a 2-D grid for prepare_decoder.
    grid = tf.reshape(x, [batch, side, side * num_latents, hidden])
    dec_in, _, _ = cia.prepare_decoder(grid, hparams)

    # Fall back to num_hidden_layers when num_latent_layers is falsy.
    num_layers = hparams.num_latent_layers or hparams.num_hidden_layers
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    # Collapse the grid back into a single length dimension.
    return tf.reshape(dec_out, [batch, side**2 * num_latents, hidden])
def transformer_image_decoder(x,
                              encoder_output,
                              ed_attention_bias,
                              hparams,
                              name="transformer_dec"):
  """Transformer image decoder over inputs with local attention.

  Args:
    x: Tensor with batch as its first dimension; it is reshaped to
      [batch, img_len, img_len, num_channels * hidden_size].
    encoder_output: Tensor, encoder output forwarded to the decoder layers.
    ed_attention_bias: Tensor, forwarded as encoder_decoder_attention_bias.
    hparams: hyperparameters; img_len, num_channels, hidden_size,
      num_decoder_layers / num_hidden_layers and dec_attention_type are read.
    name: string, variable scope.

  Returns:
    Tensor reshaped to
    [batch, img_len, img_len * num_channels, hidden_size].
  """
  with tf.variable_scope(name):
    side = hparams.img_len
    channels = hparams.num_channels
    hidden = hparams.hidden_size

    batch = common_layers.shape_list(x)[0]
    image = tf.reshape(x, [batch, side, side, channels * hidden])

    # Per the original note, prepare_decoder also shifts the targets and adds
    # 2D position embeddings.
    dec_in, _, _ = cia.prepare_decoder(image, hparams)

    # Fall back to num_hidden_layers when num_decoder_layers is falsy.
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    out_batch = common_layers.shape_list(dec_out)[0]
    return tf.reshape(dec_out, [out_batch, side, side * channels, hidden])
def transformer_latent_decoder(x,
                               encoder_output,
                               ed_attention_bias,
                               hparams,
                               name=None):
  """Decodes the latents with a Transformer using latent_attention_type.

  Args:
    x: Tensor of shape [batch, length_q, hparams.hidden_size]. length_q is
      the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope; defaults to "transformer_latent_dec".

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_latent_dec"):
    batch = common_layers.shape_list(x)[0]
    hidden = hparams.hidden_size
    num_latents = hparams.num_latents

    # Side length of the compressed image grid.
    side = hparams.img_len // 2**(hparams.num_compress_steps // 2)

    # Lay the latents out on a 2-D grid for prepare_decoder.
    grid_shape = [batch, side, side * num_latents, hidden]
    dec_in, _, _ = cia.prepare_decoder(tf.reshape(x, grid_shape), hparams)

    # Fall back to num_hidden_layers when num_latent_layers is falsy.
    num_layers = hparams.num_latent_layers or hparams.num_hidden_layers
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    # Collapse the grid back into a single length dimension.
    flat_shape = [batch, side**2 * num_latents, hidden]
    return tf.reshape(dec_out, flat_shape)
def transformer_image_decoder(x,
                              encoder_output,
                              ed_attention_bias,
                              hparams,
                              name="transformer_dec"):
  """Decodes an image with a Transformer using dec_attention_type.

  Args:
    x: Tensor of shape [batch, ...], and whose size is batch * height * width
      * hparams.num_channels * hparams.hidden_size.
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, height, width * hparams.num_channels,
    hparams.hidden_size].
  """
  with tf.variable_scope(name):
    side = hparams.img_len
    channels = hparams.num_channels
    hidden = hparams.hidden_size

    batch = common_layers.shape_list(x)[0]
    # Fold the channels into the depth dimension for prepare_decoder.
    image = tf.reshape(x, [batch, side, side, channels * hidden])
    dec_in, _, _ = cia.prepare_decoder(image, hparams)

    # Fall back to num_hidden_layers when num_decoder_layers is falsy.
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    # Unfold the channels into the width dimension.
    out_batch = common_layers.shape_list(dec_out)[0]
    return tf.reshape(dec_out, [out_batch, side, side * channels, hidden])
def transformer_image_decoder(targets,
                              encoder_output,
                              ed_attention_bias,
                              hparams,
                              name=None):
  """Decodes image targets with a Transformer using dec_attention_type.

  Args:
    targets: Tensor of shape [batch, ...], and whose size is batch * height *
      width * hparams.num_channels * hparams.hidden_size.
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope; defaults to "transformer_dec".

  Returns:
    Tensor of shape [batch, height, width * hparams.num_channels,
    hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_dec"):
    side = hparams.img_len
    channels = hparams.num_channels
    hidden = hparams.hidden_size

    batch = common_layers.shape_list(targets)[0]
    # Fold the channels into the depth dimension for prepare_decoder.
    image = tf.reshape(targets, [batch, side, side, channels * hidden])
    dec_in, _, _ = cia.prepare_decoder(image, hparams)

    # Fall back to num_hidden_layers when num_decoder_layers is falsy.
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        num_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")

    # Unfold the channels into the width dimension.
    return tf.reshape(dec_out, [batch, side, side * channels, hidden])
def body(self, features):
  """Runs the decoder-only image model, optionally returning extra losses.

  Args:
    features: dict providing "targets" and, unless hparams.unconditional is
      set, "inputs" (added to the decoder input).

  Returns:
    The cia.create_output result, or a tuple of that result and
    {"extra_loss": <sum of collected losses>} when the decoder layers
    appended anything to the losses list.

  Raises:
    ValueError: if the likelihood is DMOL and either the targets modality is
      not ImageChannelBottomIdentityModality or num_channels is not 1.
  """
  hp = copy.copy(self._hparams)
  targets = features["targets"]
  identity_modality = modalities.ImageChannelBottomIdentityModality

  if hp.likelihood == cia.DistributionType.DMOL:
    if hp.modality["targets"] != identity_modality or hp.num_channels != 1:
      raise ValueError("When using DMOL for the likelihood,modality['targets'] "
                       "must be ImageChannelBottomIdentityModality and "
                       "num_channels must be 1.")

  # Summarize the targets only when the variable scope is not in reuse mode,
  # the model is not running inference, and the targets modality is not the
  # identity modality.
  reusing = tf.get_variable_scope().reuse
  if not (reusing or hp.mode == tf.contrib.learn.ModeKeys.INFER):
    if hp.modality["targets"] != identity_modality:
      tf.summary.image("targets", tf.to_float(targets), max_outputs=1)

  # Filled in by the decoder layers (e.g. when mixture-of-experts is used).
  extra_losses = []

  dec_in, rows, cols = cia.prepare_decoder(targets, hp)

  if not hp.unconditional:
    # Add class label to decoder input; "inputs" is only read on this path.
    label = features["inputs"]
    batch = common_layers.shape_list(targets)[0]
    dec_in = dec_in + tf.reshape(label, [batch, 1, 1, hp.hidden_size])

  num_layers = hp.num_decoder_layers or hp.num_hidden_layers
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      num_layers,
      hp,
      attention_type=hp.dec_attention_type,
      losses=extra_losses,
      name="decoder")

  output = cia.create_output(dec_out, rows, cols, targets, hp)
  if not extra_losses:
    return output
  return output, {"extra_loss": tf.add_n(extra_losses)}