def body(self, features):
  """Decoder-only image transformer body, optionally class-conditioned."""
  hp = copy.copy(self._hparams)
  class_label = features["inputs"]
  targets = features["targets"]
  batch = common_layers.shape_list(targets)[0]
  summarize = not (tf.get_variable_scope().reuse or
                   hp.mode == tf.estimator.ModeKeys.PREDICT)
  if summarize:
    tf.summary.image("targets", targets, max_outputs=1)
  dec_in, n_rows, n_cols = cia.prepare_decoder(targets, hp)
  if not hp.unconditional:
    # Broadcast the class-label embedding over every spatial position.
    dec_in += tf.reshape(class_label, [batch, 1, 1, hp.hidden_size])
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, n_rows, n_cols, targets, hp)
def body(self, features):
  """Image-to-image transformer body: 2D encoder over inputs, decoder over targets."""
  hparams = copy.copy(self._hparams)
  targets = features["targets"]
  inputs = features["inputs"]
  # Only emit image summaries on a fresh (non-reuse) training/eval graph.
  if not (tf.get_variable_scope().reuse or
          hparams.mode == tf.estimator.ModeKeys.PREDICT):
    tf.summary.image("inputs", inputs, max_outputs=1)
    tf.summary.image("targets", targets, max_outputs=1)
  encoder_input = cia.prepare_encoder(inputs, hparams)
  encoder_output = cia.transformer_encoder_layers(
      encoder_input,
      hparams.num_encoder_layers,
      hparams,
      attention_type=hparams.enc_attention_type,
      name="encoder")
  # prepare_decoder shifts targets and returns the output grid dimensions.
  decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
  decoder_output = cia.transformer_decoder_layers(
      decoder_input,
      encoder_output,
      hparams.num_decoder_layers,
      hparams,
      attention_type=hparams.dec_attention_type,
      name="decoder")
  output = cia.create_output(decoder_output, rows, cols, targets, hparams)
  return output
def prior_fn(shifted_codes):
  """Calculates prior logits on discrete latents.

  Args:
    shifted_codes: A binary `Tensor` of shape [num_samples, batch_size,
      latent_size, num_codes], shifted by one to enable autoregressive
      calculation.

  Returns:
    prior_dist: Multinomial distribution with prior logits coming from
      Transformer applied to shifted input.
  """
  # NOTE(review): embedding_layer, num_codes, code_size and hparams are
  # captured from an enclosing scope not visible here.
  with tf.variable_scope("transformer_prior", reuse=tf.AUTO_REUSE):
    # Embed the one-hot codes: the sum over the num_codes axis picks out the
    # embedding row for each code.
    dense_shifted_codes = tf.reduce_sum(
        tf.reshape(embedding_layer, [1, 1, 1, num_codes, code_size]) *
        shifted_codes[Ellipsis, tf.newaxis],
        axis=-2)
    transformed_codes = cia.transformer_decoder_layers(
        inputs=dense_shifted_codes,
        encoder_output=None,
        num_layers=hparams.num_layers,
        hparams=hparams,
        attention_type=cia.AttentionType.LOCAL_1D)
    # Project decoder outputs back onto the code embeddings to get per-code
    # logits.
    logits = tf.reduce_sum(
        tf.reshape(embedding_layer, [1, 1, 1, num_codes, code_size]) *
        transformed_codes[Ellipsis, tf.newaxis, :],
        axis=-1)
    prior_dist = tfd.Multinomial(total_count=1., logits=logits)
    return prior_dist
def body(self, features):
  """Encoder-decoder image transformer body (contrib ModeKeys variant)."""
  hp = copy.copy(self._hparams)
  tgt = features["targets"]
  src = features["inputs"]
  fresh_graph = not (tf.get_variable_scope().reuse or
                     hp.mode == tf.contrib.learn.ModeKeys.INFER)
  if fresh_graph:
    tf.summary.image("inputs", src, max_outputs=1)
    tf.summary.image("targets", tgt, max_outputs=1)
  enc_out = cia.transformer_encoder_layers(
      cia.prepare_encoder(src, hp),
      hp.num_encoder_layers,
      hp,
      attention_type=hp.enc_attention_type,
      name="encoder")
  dec_in, n_rows, n_cols = cia.prepare_decoder(tgt, hp)
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      enc_out,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, n_rows, n_cols, tgt, hp)
def body(self, features):
  """Decoder-only image transformer body with optional class conditioning.

  Returns the decoder output, plus an extra-loss dict when MoE losses exist.
  """
  hparams = copy.copy(self._hparams)
  inputs = features["inputs"]
  targets = features["targets"]
  # Skip summaries when reusing variables or during inference.
  if not (tf.get_variable_scope().reuse or
          hparams.mode == tf.contrib.learn.ModeKeys.INFER):
    tf.summary.image("targets", tf.to_float(targets), max_outputs=1)
  # Extra losses list if we want to use moe.
  losses = []
  # Prepare decoder inputs and bias.
  decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
  # Add class label to decoder input.
  if not hparams.unconditional:
    # Broadcast the label embedding across all spatial positions.
    decoder_input += tf.reshape(inputs, [
        common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size
    ])
  decoder_output = cia.transformer_decoder_layers(
      decoder_input,
      None,
      hparams.num_decoder_layers or hparams.num_hidden_layers,
      hparams,
      attention_type=hparams.dec_attention_type,
      losses=losses,
      name="decoder")
  output = cia.create_output(decoder_output, rows, cols, targets, hparams)
  if losses:
    return output, {"extra_loss": tf.add_n(losses)}
  else:
    return output
def body(self, features):
  """Decoder-only transformer emitting mixture (DMOL) output parameters."""
  hp = copy.copy(self._hparams)
  cond_inputs = features["inputs"]
  targets = features["targets"]
  # Shift targets and add position signals; grid dims are unused here.
  dec_in, _, _ = cia.prepare_decoder(targets, hp)
  if not hp.unconditional:
    # Broadcast the class-label embedding across every position.
    batch = common_layers.shape_list(targets)[0]
    dec_in += tf.reshape(cond_inputs, [batch, 1, 1, hp.hidden_size])
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      hp.num_decoder_layers or hp.num_hidden_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  # reshape it into [batch, height, width, depth]
  dec_out = tf.reshape(dec_out, tf.shape(targets))
  # there are 10 sets of parameters that you need to produce, location, scale,
  # and coefficient parameter for each
  output = tf.layers.dense(
      dec_out,
      hp.num_mixtures * 10,
      use_bias=False,
      activation=None,
      name="output_mixtures_conv")
  # TODO(avaswani) Figure out if we need residuals or layer norm
  return output
def body(self, features):
  """Class-conditional decoder-only image transformer (contrib ModeKeys)."""
  hp = copy.copy(self._hparams)
  label_emb = features["inputs"]
  targets = features["targets"]
  batch = common_layers.shape_list(targets)[0]
  if not (tf.get_variable_scope().reuse or
          hp.mode == tf.contrib.learn.ModeKeys.INFER):
    tf.summary.image("targets", targets, max_outputs=1)
  dec_in, n_rows, n_cols = cia.prepare_decoder(targets, hp)
  if not hp.unconditional:
    # Add the class-label embedding at every position via broadcasting.
    dec_in += tf.reshape(label_emb, [batch, 1, 1, hp.hidden_size])
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      None,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, n_rows, n_cols, targets, hp)
def transformer_latent_decoder(encoder_output, ed_attention_bias, targets,
                               hparams, name="transformer_latent_dec"):
  """Transformer decoder over compressed latents.

  Args:
    encoder_output: Tensor, encoder output, attended to via
      encoder-decoder attention.
    ed_attention_bias: Tensor, encoder-decoder attention bias.
    targets: Tensor whose total size is
      batch * (img_len // compress_ratio) *
      (img_len * num_latents // compress_ratio) * hidden_size.
    hparams: hyperparameters object.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, latent_length, hidden_size], where latent_length
    is img_len * img_len * num_latents / 2**num_compress_steps.
  """
  with tf.variable_scope(name):
    batch_size = common_layers.shape_list(targets)[0]
    compress_ratio = 2**(hparams.num_compress_steps // 2)
    # Lay the flat latents out on a compressed 2-D grid.
    # BUGFIX: use floor division; Python 3 "/" yields floats, which are
    # invalid tf.reshape dimension sizes.
    targets = tf.reshape(targets, [
        batch_size, hparams.img_len // compress_ratio,
        (hparams.img_len * hparams.num_latents) // compress_ratio,
        hparams.hidden_size
    ])
    # Prepare decoder inputs and bias (shifts targets, adds 2D positions).
    decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    # Flatten the grid back to [batch, latent_length, hidden_size].
    decoder_output = tf.reshape(decoder_output, [
        decoder_output_shape[0],
        (hparams.img_len * hparams.img_len * hparams.num_latents) //
        (2**hparams.num_compress_steps), hparams.hidden_size
    ])
    return decoder_output
def body(self, features):
  """Decoder-only image transformer body with DMOL-likelihood validation.

  Returns the decoder output, plus an extra-loss dict when MoE losses exist.
  """
  hparams = copy.copy(self._hparams)
  targets = features["targets"]
  # DMOL likelihood requires single-channel identity-bottom targets.
  if (hparams.likelihood == cia.DistributionType.DMOL and
      hparams.num_channels != 1):
    raise ValueError("When using DMOL for the likelihood, bottom function "
                     " must be identity and num_channels must be 1.")
  # Only emit summaries on a fresh training/eval graph.
  if (not tf.get_variable_scope().reuse and
      hparams.mode != tf.estimator.ModeKeys.PREDICT):
    tf.summary.image("targets", tf.to_float(targets), max_outputs=1)
  # Extra losses list if we want to use moe.
  losses = []
  # Prepare decoder inputs and bias.
  decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
  # Add class label to decoder input.
  if not hparams.unconditional:
    inputs = features["inputs"]
    # Broadcast the label embedding across all spatial positions.
    decoder_input += tf.reshape(
        inputs,
        [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
  decoder_output = cia.transformer_decoder_layers(
      decoder_input,
      None,
      hparams.num_decoder_layers or hparams.num_hidden_layers,
      hparams,
      attention_type=hparams.dec_attention_type,
      losses=losses,
      name="decoder")
  output = cia.create_output(decoder_output, rows, cols, targets, hparams)
  if losses:
    return output, {"extra_loss": tf.add_n(losses)}
  else:
    return output
def generator(self, inputs, targets):
  """From tensor2tensor.models.img2img_transformer_2d."""
  hp = copy.copy(self._hparams)
  # Encode the source image.
  enc_out = cia.transformer_encoder_layers(
      cia.prepare_encoder(inputs, hp),
      hp.num_encoder_layers,
      hp,
      attention_type=hp.enc_attention_type,
      name="encoder")
  # Decode targets with encoder-decoder attention.
  dec_in, n_rows, n_cols = cia.prepare_decoder(targets, hp)
  dec_out = cia.transformer_decoder_layers(
      dec_in,
      enc_out,
      hp.num_decoder_layers,
      hp,
      attention_type=hp.dec_attention_type,
      name="decoder")
  return cia.create_output(dec_out, n_rows, n_cols, targets, hp)
def transformer_image_decoder(encoder_output, ed_attention_bias, targets,
                              hparams, name="transformer_dec"):
  """Original Transformer decoder.

  Args:
    encoder_output: Tensor, attended to via encoder-decoder attention.
    ed_attention_bias: Tensor, encoder-decoder attention bias.
    targets: Tensor whose total size is
      batch * img_len * img_len * num_channels * hidden_size.
    hparams: hyperparameters object.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, img_len, img_len * num_channels, hidden_size].
  """
  with tf.variable_scope(name):
    batch_size = common_layers.shape_list(targets)[0]
    # Reshape targets as b, 32, 32, 3*hidden size].
    targets = tf.reshape(targets, [
        batch_size, hparams.img_len, hparams.img_len,
        hparams.num_channels * hparams.hidden_size
    ])
    # Prepare decoder inputs and bias. This also shifts targets and adds 2D
    # position embeddings to target.
    decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_decoder_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    # Unfold channels into the width axis for the output.
    decoder_output = tf.reshape(decoder_output, [
        decoder_output_shape[0], hparams.img_len,
        hparams.img_len * hparams.num_channels, hparams.hidden_size
    ])
    return decoder_output
def body(self, features):
  """Encoder-decoder image transformer with a block-expanded output layer.

  Runs a 2D encoder/decoder, then expands each decoder position into
  `block_size` output vectors via a wide dense-relu-dense layer, and applies
  layer postprocessing against the (broadcast) decoder output.
  """
  assert self._hparams.block_size > 0
  # The reshape logic below is incompatible with XLA compilation.
  assert not common_layers.is_xla_compiled()
  hparams = copy.copy(self._hparams)
  targets = features["targets"]
  inputs = features["inputs"]
  # Only emit summaries on a fresh (non-reuse) training/eval graph.
  if not (tf.get_variable_scope().reuse or
          hparams.mode == tf.estimator.ModeKeys.PREDICT):
    tf.summary.image("inputs", inputs, max_outputs=1)
    tf.summary.image("targets", targets, max_outputs=1)
  encoder_input = cia.prepare_encoder(inputs, hparams)
  encoder_output = cia.transformer_encoder_layers(
      encoder_input,
      hparams.num_encoder_layers,
      hparams,
      attention_type=hparams.enc_attention_type,
      name="encoder")
  decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
  decoder_output = cia.transformer_decoder_layers(
      decoder_input,
      encoder_output,
      hparams.num_decoder_layers,
      hparams,
      attention_type=hparams.dec_attention_type,
      name="decoder")
  assert not isinstance(decoder_output, tuple)
  assert len(decoder_output.shape) == 4
  relu_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(self._hparams, "relu_dropout_broadcast_dims", "")))
  with tf.variable_scope("block_size_%d" % self._hparams.block_size):
    tf.logging.info("Using block_size %d", self._hparams.block_size)
    # Widen the FFN by block_size so each position emits block_size vectors.
    block_output = common_layers.dense_relu_dense(
        decoder_output,
        self._hparams.block_size * self._hparams.filter_size,
        self._hparams.block_size * self._hparams.hidden_size,
        dropout=self._hparams.relu_dropout,
        dropout_broadcast_dims=relu_dropout_broadcast_dims)
  # NOTE: rows/cols from prepare_decoder are shadowed here by the actual
  # decoder output dimensions.
  batch_size, rows, cols = common_layers.shape_list(decoder_output)[:3]
  # Insert a singleton block axis so decoder_output broadcasts against the
  # [batch, rows, cols, block_size, hidden] block output below.
  decoder_output = tf.reshape(
      decoder_output, [batch_size, rows, cols, 1, self._hparams.hidden_size])
  block_output = tf.reshape(block_output, [
      batch_size, rows, cols, self._hparams.block_size,
      self._hparams.hidden_size
  ])
  block_output = common_layers.layer_postprocess(decoder_output, block_output,
                                                 self._hparams)
  return block_output
def _build_layers_v2(self, input_dict, num_outputs, options):
  """Builds an image-transformer policy network (RLlib model hook).

  Encodes observations with a 2D transformer encoder, decodes previous
  actions against them, then applies two conv layers for the policy head.

  Args:
    input_dict: dict of input tensors; reads "obs" (observations) and
      "prev_actions" (decoder targets).
    num_outputs: int, size of the final output layer.
    options: model-options dict; hparams come from
      options["custom_options"]["hparams"], activation from
      options["conv_activation"].

  Returns:
    Tuple of (flattened fc2 output, flattened fc1 features).
  """
  hparams = copy.copy(options["custom_options"]["hparams"])
  targets = input_dict["prev_actions"]
  inputs = input_dict["obs"]
  with tf.name_scope('enc_prep'):
    encoder_input = cia.prepare_encoder(inputs, hparams)
  with tf.name_scope('enc_layers'):
    encoder_output = cia.transformer_encoder_layers(
        encoder_input,
        hparams.num_encoder_layers,
        hparams,
        attention_type=hparams.enc_attention_type,
        name="encoder")
  with tf.name_scope('dec_prep'):
    # Grid dimensions are unused: the decoder output feeds conv layers
    # directly instead of cia.create_output.
    decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
  with tf.name_scope('dec_layers'):
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_decoder_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        name="decoder")
  # Policy head: 3x3 stride-2 conv, then a 1x1 conv to num_outputs.
  out_size, kernel, stride = 32, [3, 3], 2
  activation = get_activation_fn(options.get("conv_activation"))
  fc1 = slim.conv2d(
      decoder_output,
      out_size,
      kernel,
      stride,
      activation_fn=activation,
      padding="VALID",
      scope="fc1")
  fc2 = slim.conv2d(
      fc1,
      num_outputs, [1, 1],
      activation_fn=None,
      normalizer_fn=None,
      scope="fc2")
  return flatten(fc2), flatten(fc1)
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None):
  """Original Transformer decoder.

  Dispatches between a text ("translate") decoder and an image decoder with
  local attention, selected by `task` or `hparams.task`.

  Args:
    encoder_output: encoder output tensor (unused by the image branch).
    encoder_decoder_attention_bias: encoder-decoder attention bias tensor.
    targets: target tensor to decode.
    hparams: hyperparameters object.
    name: string, variable scope.
    task: optional override, "translate" or "image".

  Returns:
    4-D decoder output tensor of shape [batch, length, 1, hidden_size].
  """
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)
      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))
      decoder_input = tf.nn.dropout(
          decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # have to reshape targets as b, 32, 32, 3 * hidden size] because
      # otherwise prepare_image will choke
      targets = tf.reshape(targets, [
          tf.shape(targets)[0], hparams.img_len, hparams.img_len,
          hparams.num_channels * hparams.hidden_size
      ])
      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      # NOTE(review): `inputs` is None at this point, so tf.reshape(inputs,
      # ...) will fail whenever hparams.drop_inputs is False — confirm that
      # callers always set drop_inputs=True for the image task.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(inputs, [
            common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size
        ])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      decoder_output = tf.reshape(
          decoder_output, [decoder_output_shape[0], -1, 1, hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output
def decode_transformer(encoder_output, encoder_decoder_attention_bias,
                       targets, hparams, name, task=None):
  """Original Transformer decoder.

  Dispatches between a text ("translate") decoder and an image decoder with
  local attention, selected by `task` or `hparams.task`.

  Args:
    encoder_output: encoder output tensor (unused by the image branch).
    encoder_decoder_attention_bias: encoder-decoder attention bias tensor.
    targets: target tensor to decode.
    hparams: hyperparameters object.
    name: string, variable scope.
    task: optional override, "translate" or "image".

  Returns:
    4-D decoder output tensor of shape [batch, length, 1, hidden_size].
  """
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)
      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))
      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)
      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # have to reshape targets as b, 32, 32, 3 * hidden size] because
      # otherwise prepare_image will choke
      targets = tf.reshape(targets,
                           [tf.shape(targets)[0], hparams.img_len,
                            hparams.img_len,
                            hparams.num_channels * hparams.hidden_size])
      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      # NOTE(review): `inputs` is None at this point, so tf.reshape(inputs,
      # ...) will fail whenever hparams.drop_inputs is False — confirm that
      # callers always set drop_inputs=True for the image task.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      decoder_output = tf.reshape(
          decoder_output,
          [decoder_output_shape[0], -1, 1, hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output
def transformer_latent_decoder(x, encoder_output, ed_attention_bias, hparams,
                               name="transformer_latent_dec"):
  """Transformer decoder over latents using latent_attention_type.

  Args:
    x: Tensor of shape [batch, height, width, hidden_dim].
    encoder_output: Tensor, encoder output of shape [batch, length,
      hidden_dim].
    ed_attention_bias: Tensor, bias for x.
    hparams: Dict, hyperparameters.
    name: string, variable scope.

  Returns:
    x: Tensor of shape [batch, latent_length, hidden_dim], where
    latent_length is img_len * img_len * num_latents / 2**num_compress_steps.
  """
  with tf.variable_scope(name):
    batch_size = common_layers.shape_list(x)[0]
    compress_ratio = 2**(hparams.num_compress_steps // 2)
    # Lay the latents out on a compressed 2-D grid.
    # BUGFIX: use floor division; Python 3 "/" yields floats, which are
    # invalid tf.reshape dimension sizes.
    x = tf.reshape(x, [
        batch_size, hparams.img_len // compress_ratio,
        (hparams.img_len * hparams.num_latents) // compress_ratio,
        hparams.hidden_size
    ])
    # Prepare decoder inputs and bias.
    decoder_input, _, _ = cia.prepare_decoder(x, hparams)
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    # Flatten the grid back to [batch, latent_length, hidden_size].
    decoder_output = tf.reshape(decoder_output, [
        decoder_output_shape[0],
        (hparams.img_len * hparams.img_len * hparams.num_latents) //
        (2**hparams.num_compress_steps), hparams.hidden_size
    ])
    return decoder_output
def transformer_latent_decoder(x, encoder_output, ed_attention_bias, hparams,
                               name="transformer_latent_dec"):
  """Transformer decoder over latents using latent_attention_type.

  Args:
    x: Tensor of shape [batch, ...], and whose size is batch * length_q *
      hparams.hidden_size. Here, length_q is the latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape [batch,
      hparams.num_heads, length_q, length_kv]. Encoder-decoder attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name):
    batch_size = common_layers.shape_list(x)[0]
    compress_ratio = 2**(hparams.num_compress_steps // 2)
    # Lay the latents out on a compressed 2-D grid.
    # BUGFIX: use floor division; Python 3 "/" yields floats, which are
    # invalid tf.reshape dimension sizes.
    x = tf.reshape(x, [
        batch_size, hparams.img_len // compress_ratio,
        (hparams.img_len * hparams.num_latents) // compress_ratio,
        hparams.hidden_size
    ])
    decoder_input, _, _ = cia.prepare_decoder(x, hparams)
    decoder_output = cia.transformer_decoder_layers(
        decoder_input,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    # Flatten the grid back to [batch, length_q, hidden_size].
    decoder_output = tf.reshape(decoder_output, [
        decoder_output_shape[0],
        (hparams.img_len * hparams.img_len * hparams.num_latents) //
        (2**hparams.num_compress_steps), hparams.hidden_size
    ])
    return decoder_output
def testTransformerDecoderLayersGlobal(self):
  """Checks causal behavior of GLOBAL attention in transformer_decoder_layers."""
  # Three examples, two one-hot positions each; the first two examples are
  # identical and the third differs only in its first position.
  one_hot_data = tf.constant([[[0., 1.], [1., 0.]],
                              [[0., 1.], [1., 0.]],
                              [[1., 0.], [1., 0.]]])
  hparams = common_hparams.basic_params1()
  hparams.hidden_size = 4
  hparams.num_layers = 1
  hparams.layer_prepostprocess_dropout = 0.
  # Hyperparameters the attention layers need but basic_params1 lacks.
  hparams.add_hparam("attention_key_channels", None)
  hparams.add_hparam("attention_value_channels", None)
  hparams.add_hparam("num_heads", 1)
  hparams.add_hparam("attention_dropout", 0.)
  hparams.add_hparam("shared_rel", False)
  hparams.add_hparam("block_width", 1)
  hparams.add_hparam("block_length", 1)
  hparams.add_hparam("q_filter_width", 1)
  hparams.add_hparam("kv_filter_width", 1)
  hparams.add_hparam("filter_size", 16)
  hparams.add_hparam("ffn_layer", "conv_hidden_relu")
  hparams.add_hparam("relu_dropout", 0.)
  conv_1d = tf.keras.layers.Conv1D(
      filters=hparams.hidden_size, kernel_size=1, use_bias=False)
  # Shift right by one step so position t only conditions on positions < t.
  shifted_data = tf.pad(one_hot_data, [[0, 0], [1, 0], [0, 0]])[..., :-1, :]
  net = conv_1d(shifted_data)
  output = common_image_attention.transformer_decoder_layers(
      inputs=net,
      encoder_output=None,
      num_layers=hparams.num_layers,
      hparams=hparams,
      self_attention_bias=common_image_attention.get_self_attention_bias(net),
      attention_type=common_image_attention.AttentionType.GLOBAL)
  self.evaluate(tf.global_variables_initializer())
  output_val = self.evaluate(output)
  # The outputs for the padded dimension should be equal across all data.
  self.assertAllEqual(output_val[0, 0], output_val[1, 0])
  self.assertAllEqual(output_val[1, 0], output_val[2, 0])
  # The first and second elements of the batch are identical, so they should
  # have the same outputs for the second latent dimension as well.
  self.assertAllEqual(output_val[0, 1], output_val[1, 1])
def iaf_scale_from_transformer(shifted_codes, code_size, name=None):
  """Returns unconstrained IAF scale tensor generated by Transformer.

  Args:
    shifted_codes: Tensor with shape [num_samples, batch_size, latent_size,
      num_codes], with the first latent_size dimension a tensor of zeros and
      the others shifted down one (with the original last latent dimension
      missing).
    code_size: Size of each latent code.
    name: String used for name scope.

  Returns:
    unconstrained_scale: Tensor with shape [latent_size, latent_size],
      generated by passing shifted codes through Transformer decoder layers.
      It can take on any real value as it will later be passed through a
      softplus.
  """
  # `name` is rebound to the resolved scope string, so the variable names
  # below are prefixed with it even when name=None was passed.
  with tf.name_scope(name, default_name="iaf_scale_from_transformer") as name:
    num_codes = shifted_codes.shape[-1]
    embedding_layer = tf.get_variable(name + "iaf_embedding",
                                      [num_codes, code_size],
                                      dtype=tf.float32)
    hparams = transformer_hparams(hidden_size=code_size)
    # Embed one-hot codes: summing over num_codes selects each embedding row.
    dense_shifted_codes = tf.reduce_sum(
        tf.reshape(embedding_layer, [1, 1, 1, num_codes, code_size]) *
        shifted_codes[Ellipsis, tf.newaxis],
        axis=-2)
    transformed_codes = cia.transformer_decoder_layers(
        inputs=dense_shifted_codes,
        encoder_output=None,
        num_layers=hparams.num_layers,
        hparams=hparams,
        attention_type=cia.AttentionType.LOCAL_1D,
        name=name)
    # Project decoder outputs back onto the embeddings to get scale values.
    unconstrained_scale = tf.reduce_sum(
        tf.reshape(embedding_layer, [1, 1, 1, num_codes, code_size]) *
        transformed_codes[Ellipsis, tf.newaxis, :],
        axis=-1)
    # Multiplying by scale_weight allows identity initialization.
    scale_weight = tf.get_variable(name + "scale_weight", [],
                                   dtype=tf.float32,
                                   initializer=tf.zeros_initializer())
    unconstrained_scale = unconstrained_scale * scale_weight
  return unconstrained_scale
def transformer_latent_decoder(x, encoder_output, ed_attention_bias, hparams,
                               name=None):
  """Transformer decoder over latents using latent_attention_type.

  Args:
    x: Tensor of shape [batch, length_q, hparams.hidden_size]. length_q is the
      latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape
      [batch, hparams.num_heads, length_q, length_kv]. Encoder-decoder
      attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_latent_dec"):
    batch = common_layers.shape_list(x)[0]
    side = hparams.img_len // 2**(hparams.num_compress_steps // 2)
    # Lay the flat latents out on a 2-D grid for spatial attention.
    grid = tf.reshape(
        x, [batch, side, side * hparams.num_latents, hparams.hidden_size])
    dec_in, _, _ = cia.prepare_decoder(grid, hparams)
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    # Flatten the grid back to [batch, length_q, hidden_size].
    return tf.reshape(
        dec_out, [batch, side**2 * hparams.num_latents, hparams.hidden_size])
def transformer_image_decoder(x, encoder_output, ed_attention_bias, hparams,
                              name="transformer_dec"):
  """Transformer image decoder over inputs with local attention.

  Args:
    x: Tensor of shape [batch, height, width, hidden_dim].
    encoder_output: Tensor, encoder output of shape [batch, length,
      hidden_dim].
    ed_attention_bias: Tensor, bias for x.
    hparams: Dict, hyperparameters.
    name: string, variable scope.

  Returns:
    x: Tensor of shape [batch, height, width, hidden_dim].
  """
  with tf.variable_scope(name):
    batch = common_layers.shape_list(x)[0]
    # Fold the channel axis into depth: [b, img_len, img_len, c * hidden].
    grid = tf.reshape(x, [
        batch, hparams.img_len, hparams.img_len,
        hparams.num_channels * hparams.hidden_size
    ])
    # prepare_decoder shifts targets and adds 2-D position embeddings.
    dec_in, _, _ = cia.prepare_decoder(grid, hparams)
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        hparams.num_decoder_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    out_batch = common_layers.shape_list(dec_out)[0]
    # Unfold channels into the width axis for the output grid.
    return tf.reshape(dec_out, [
        out_batch, hparams.img_len, hparams.img_len * hparams.num_channels,
        hparams.hidden_size
    ])
def transformer_latent_decoder(x, encoder_output, ed_attention_bias, hparams,
                               name=None):
  """Transformer decoder over latents using latent_attention_type.

  Args:
    x: Tensor of shape [batch, length_q, hparams.hidden_size]. length_q is the
      latent length, which is
      height * width * hparams.num_latents / (2**hparams.num_compress_steps).
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape
      [batch, hparams.num_heads, length_q, length_kv]. Encoder-decoder
      attention bias.
    hparams: HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, length_q, hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_latent_dec"):
    batch = common_layers.shape_list(x)[0]
    side = hparams.img_len // 2**(hparams.num_compress_steps // 2)
    # Lay the flat latents out on a 2-D grid for spatial attention.
    grid = tf.reshape(
        x, [batch, side, side * hparams.num_latents, hparams.hidden_size])
    dec_in, _, _ = cia.prepare_decoder(grid, hparams)
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        hparams.num_latent_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.latent_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    # Flatten the grid back to [batch, length_q, hidden_size].
    return tf.reshape(
        dec_out, [batch, side**2 * hparams.num_latents, hparams.hidden_size])
def transformer_image_decoder(x, encoder_output, ed_attention_bias, hparams,
                              name="transformer_dec"):
  """Transformer image decoder over inputs with local attention.

  Args:
    x: Tensor of shape [batch, ...], and whose size is batch * height * width
      * hparams.num_channels * hparams.hidden_size.
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape
      [batch, hparams.num_heads, length_q, length_kv]. Encoder-decoder
      attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, height, width * hparams.num_channels,
    hparams.hidden_size].
  """
  with tf.variable_scope(name):
    batch = common_layers.shape_list(x)[0]
    # Fold channels into depth: [b, img_len, img_len, channels * hidden].
    grid = tf.reshape(x, [
        batch, hparams.img_len, hparams.img_len,
        hparams.num_channels * hparams.hidden_size
    ])
    dec_in, _, _ = cia.prepare_decoder(grid, hparams)
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        hparams.num_decoder_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    out_batch = common_layers.shape_list(dec_out)[0]
    # Unfold channels into the width axis for the output grid.
    return tf.reshape(dec_out, [
        out_batch, hparams.img_len, hparams.img_len * hparams.num_channels,
        hparams.hidden_size
    ])
def transformer_image_decoder(targets, encoder_output, ed_attention_bias,
                              hparams, name=None):
  """Transformer image decoder over targets with local attention.

  Args:
    targets: Tensor of shape [batch, ...], and whose size is batch * height *
      width * hparams.num_channels * hparams.hidden_size.
    encoder_output: Tensor of shape [batch, length_kv, hparams.hidden_size].
    ed_attention_bias: Tensor which broadcasts with shape
      [batch, hparams.num_heads, length_q, length_kv]. Encoder-decoder
      attention bias.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    Tensor of shape [batch, height, width * hparams.num_channels,
    hparams.hidden_size].
  """
  with tf.variable_scope(name, default_name="transformer_dec"):
    batch = common_layers.shape_list(targets)[0]
    # Fold channels into depth: [b, img_len, img_len, channels * hidden].
    grid = tf.reshape(targets, [
        batch, hparams.img_len, hparams.img_len,
        hparams.num_channels * hparams.hidden_size
    ])
    dec_in, _, _ = cia.prepare_decoder(grid, hparams)
    dec_out = cia.transformer_decoder_layers(
        dec_in,
        encoder_output,
        hparams.num_decoder_layers or hparams.num_hidden_layers,
        hparams,
        attention_type=hparams.dec_attention_type,
        encoder_decoder_attention_bias=ed_attention_bias,
        name="decoder")
    # Unfold channels into the width axis for the output grid.
    return tf.reshape(dec_out, [
        batch, hparams.img_len, hparams.img_len * hparams.num_channels,
        hparams.hidden_size
    ])
def body(self, features):
  """Decoder-only image transformer body with DMOL/modality validation.

  Returns the decoder output, plus an extra-loss dict when MoE losses exist.
  """
  hparams = copy.copy(self._hparams)
  targets = features["targets"]
  # DMOL likelihood requires the identity channel-bottom modality and a
  # single channel.
  if (hparams.likelihood == cia.DistributionType.DMOL and
      (hparams.modality["targets"] !=
       modalities.ImageChannelBottomIdentityModality or
       hparams.num_channels != 1)):
    raise ValueError("When using DMOL for the likelihood,modality['targets'] "
                     "must be ImageChannelBottomIdentityModality and "
                     "num_channels must be 1.")
  # Emit summaries only on fresh training/eval graphs with a displayable
  # target modality.
  if (not tf.get_variable_scope().reuse and
      hparams.mode != tf.contrib.learn.ModeKeys.INFER and
      hparams.modality["targets"] !=
      modalities.ImageChannelBottomIdentityModality):
    tf.summary.image("targets", tf.to_float(targets), max_outputs=1)
  # Extra losses list if we want to use moe.
  losses = []
  # Prepare decoder inputs and bias.
  decoder_input, rows, cols = cia.prepare_decoder(targets, hparams)
  # Add class label to decoder input.
  if not hparams.unconditional:
    inputs = features["inputs"]
    # Broadcast the label embedding across all spatial positions.
    decoder_input += tf.reshape(
        inputs,
        [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
  decoder_output = cia.transformer_decoder_layers(
      decoder_input,
      None,
      hparams.num_decoder_layers or hparams.num_hidden_layers,
      hparams,
      attention_type=hparams.dec_attention_type,
      losses=losses,
      name="decoder")
  output = cia.create_output(decoder_output, rows, cols, targets, hparams)
  if losses:
    return output, {"extra_loss": tf.add_n(losses)}
  else:
    return output